diff --git a/include/operations/blas2_trees.h b/include/operations/blas2_trees.h
index 34937283e..9dbbedebb 100644
--- a/include/operations/blas2_trees.h
+++ b/include/operations/blas2_trees.h
@@ -502,6 +502,64 @@ make_trsv(vector_t &lhs_, matrix_t &matrix_, sync_t &sync_) {
       subgroups, is_upper, is_transposed, is_unit>(lhs_, matrix_, k_, sync_);
 }
 
+/**
+ * @struct Ger
+ * @brief Tree node representing a rank-1 update, i.e. the sum of a scaled
+ * vector-vector outer product with a matrix. It computes lhs_ such that
+ *
+ * lhs_ = scalar_ * ( rhs_1_ * rhs_2_^t ) + lhs_
+ *
+ * @param lhs_ input/output matrix
+ * @param scalar_ scaling factor for the outer product
+ * @param rhs_1_ first input vector
+ * @param rhs_2_ second input vector
+ * @param nRowsWG_ number of rows of the work-group tile
+ * @param nColsWG_ number of columns of the work-group tile
+ * @param nWG_row_ number of tiles per row of the global range
+ * @param nWG_col_ number of tiles per column of the global range
+ *
+ */
+template <typename lhs_t, typename rhs_1_t, typename rhs_2_t>
+struct Ger {
+  using value_t = typename rhs_2_t::value_t;
+  using index_t = typename rhs_2_t::index_t;
+
+  lhs_t lhs_;
+  value_t scalar_;
+  rhs_1_t rhs_1_;
+  rhs_2_t rhs_2_;
+  index_t nRowsWG_;
+  index_t nColsWG_;
+  index_t nWG_row_;
+  index_t nWG_col_;
+
+  Ger(lhs_t &_l, value_t _scl, rhs_1_t &_r1, rhs_2_t &_r2, index_t &_nRowsWG,
+      index_t &_nColsWG, index_t &_nWG_row, index_t &_nWG_col);
+
+  index_t get_size() const;
+  bool valid_thread(cl::sycl::nd_item<1> ndItem) const;
+  value_t eval(index_t i);
+  value_t eval(cl::sycl::nd_item<1> ndItem);
+  template <typename sharedT>
+  value_t eval(sharedT shrMem, cl::sycl::nd_item<1> ndItem);
+  void bind(cl::sycl::handler &h);
+  void adjust_access_displacement();
+};
+
+/*!
+ @brief Generator/factory for GER trees.
+ */
+template <typename lhs_t, typename rhs_1_t, typename rhs_2_t>
+Ger<lhs_t, rhs_1_t, rhs_2_t> make_ger(lhs_t &lhs_,
+                                      typename lhs_t::value_t scalar_,
+                                      rhs_1_t &rhs_1_, rhs_2_t &rhs_2_,
+                                      typename rhs_2_t::index_t nRowsWG_,
+                                      typename rhs_2_t::index_t nColsWG_,
+                                      typename rhs_2_t::index_t nWG_row_,
+                                      typename rhs_2_t::index_t nWG_col_) {
+  return Ger<lhs_t, rhs_1_t, rhs_2_t>(lhs_, scalar_, rhs_1_, rhs_2_, nRowsWG_,
+                                      nColsWG_, nWG_row_, nWG_col_);
+}
 
 /**** GER BY ROWS M ROWS x N BLOCK USING PROPERLY THE SHARED MEMORY ****/
 // template <...>
diff --git a/src/interface/blas2_interface.hpp b/src/interface/blas2_interface.hpp
index 71dbee066..14c5fad24 100644
--- a/src/interface/blas2_interface.hpp
+++ b/src/interface/blas2_interface.hpp
@@ -878,7 +878,7 @@ typename sb_handle_t::event_t _ger_impl(
     container_t0 _vx, increment_t _incx, container_t1 _vy, increment_t _incy,
     container_t2 _mA, index_t _lda,
     const typename sb_handle_t::event_t& _dependencies, index_t _localSize = 0,
-    index_t _scratchPadSize = 0, index_t _nRowsWG = 0, index_t _nColsWG = 0) {
+    bool _useLocalMem = true, index_t _nRowsWG = 0, index_t _nColsWG = 0) {
   index_t M = _M;
   index_t N = _N;
   auto mA = make_matrix_view<col_major>(_mA, M, N, _lda);
@@ -887,24 +887,39 @@ typename sb_handle_t::event_t _ger_impl(
   typename VectorViewType<container_t1, index_t, increment_t>::type vy =
       make_vector_view(_vy, _incy, N);
 
-  const index_t localSize =
-      (_localSize == 0) ? sb_handle.get_work_group_size() : _localSize;
-  const index_t nRowsWG = (_nRowsWG == 0) ? localSize : std::min(M, _nRowsWG);
+  _localSize = (_localSize == 0) ? sb_handle.get_work_group_size() : _localSize;
+  _nRowsWG = (_nRowsWG == 0) ? _localSize : _nRowsWG;
+  _nColsWG = (_nColsWG == 0) ? _localSize : _nColsWG;
 
-  const index_t nColsWG = (_nColsWG == 0) ? localSize : std::min(N, _nColsWG);
+  assert(_localSize % _nRowsWG == 0);
+  assert((_nRowsWG * _nColsWG) % _localSize == 0);
+  assert(_nColsWG % (_localSize / _nRowsWG) == 0);
 
-  const index_t scratchPadSize =
-      (_localSize == 0) ? localSize : _scratchPadSize;
+  if (_useLocalMem) {
+    assert((_nRowsWG <= _localSize) && (_nColsWG <= _localSize));
+  } else {
+    std::vector<size_t> subgroup_sizes =
+        sb_handle.get_queue()
+            .get_device()
+            .template get_info<cl::sycl::info::device::sub_group_sizes>();
+    size_t min_subgroup_size = *subgroup_sizes.begin();
+    size_t max_subgroup_size = *subgroup_sizes.rbegin();
+    assert(((_nRowsWG * _nColsWG) / _localSize) <= min_subgroup_size);
+    assert(_nRowsWG % max_subgroup_size == 0);
+  }
 
-  const index_t nWGPerCol = (N - 1) / nColsWG + 1;
-  const index_t nWGPerRow = (M - 1) / nRowsWG + 1;
-  const index_t globalSize = localSize * nWGPerRow * nWGPerCol;
+  const index_t nWGPerCol = (N - 1) / _nColsWG + 1;
+  const index_t nWGPerRow = (M - 1) / _nRowsWG + 1;
+  const index_t globalSize = _localSize * nWGPerRow * nWGPerCol;
 
   typename sb_handle_t::event_t ret;
 
   auto assignOp =
-      make_ger_col(mA, _alpha, vx, vy, nWGPerRow, nWGPerCol, scratchPadSize);
-  return sb_handle.execute(assignOp, localSize, globalSize, scratchPadSize,
-                           _dependencies);
+      make_ger(mA, _alpha, vx, vy, _nRowsWG, _nColsWG, nWGPerRow, nWGPerCol);
+
+  return _useLocalMem ? sb_handle.execute(assignOp, _localSize, globalSize,
+                                          _nRowsWG + _nColsWG, _dependencies)
+                      : sb_handle.execute(assignOp, _localSize, globalSize,
+                                          _dependencies);
 }
 /*! _SYR.
@@ -1280,10 +1295,30 @@ typename sb_handle_t::event_t inline _ger(
     container_t0 _vx, increment_t _incx, container_t1 _vy, increment_t _incy,
     container_t2 _mA, index_t _lda,
     const typename sb_handle_t::event_t& _dependencies) {
-  // TODO: Here we can use some heuristics to select localn global, local, and
-  // scratch size per device
+  index_t localSize = 0;
+  bool useLocalMem = true;
+  index_t nRowsWG = 0;
+  index_t nColsWG = 0;
+
+#if defined(INTEL_GPU)
+  localSize = 32;
+  useLocalMem = false;
+  nRowsWG = 32;
+  nColsWG = 8;
+#elif defined(NVIDIA_GPU)
+  localSize = 256;
+  useLocalMem = (_N < 8192 && _M < 8192) ? false : true;
+  nRowsWG = 32;
+  nColsWG = 32;
+#elif defined(AMD_GPU)
+  localSize = (_N < 8192 && _M < 8192) ? 512 : 256;
+  useLocalMem = (_N < 8192 && _M < 8192) ? false : true;
+  nRowsWG = (_N < 8192 && _M < 8192) ? 64 : 128;
+  nColsWG = (_N < 8192 && _M < 8192) ? 64 : 256;
+#endif
+
   return _ger_impl(sb_handle, _M, _N, _alpha, _vx, _incx, _vy, _incy, _mA, _lda,
-                   _dependencies);
+                   _dependencies, localSize, useLocalMem, nRowsWG, nColsWG);
 }
 
 template <...>
diff --git a/src/operations/blas2_trees.hpp b/src/operations/blas2_trees.hpp
--- a/src/operations/blas2_trees.hpp
+++ b/src/operations/blas2_trees.hpp
@@ ... @@
+template <typename lhs_t, typename rhs_1_t, typename rhs_2_t>
+PORTBLAS_INLINE Ger<lhs_t, rhs_1_t, rhs_2_t>::Ger(
+    lhs_t &_l, value_t _scl, rhs_1_t &_r1, rhs_2_t &_r2, index_t &_nRowsWG,
+    index_t &_nColsWG, index_t &_nWG_row, index_t &_nWG_col)
+    : lhs_(_l),
+      scalar_(_scl),
+      rhs_1_(_r1),
+      rhs_2_(_r2),
+      nRowsWG_(_nRowsWG),
+      nColsWG_(_nColsWG),
+      nWG_row_(_nWG_row),
+      nWG_col_(_nWG_col) {}
+
+template <typename lhs_t, typename rhs_1_t, typename rhs_2_t>
+PORTBLAS_INLINE typename Ger<lhs_t, rhs_1_t, rhs_2_t>::index_t
+Ger<lhs_t, rhs_1_t, rhs_2_t>::get_size() const {
+  return rhs_1_.get_size();
+}
+template <typename lhs_t, typename rhs_1_t, typename rhs_2_t>
+PORTBLAS_INLINE bool Ger<lhs_t, rhs_1_t, rhs_2_t>::valid_thread(
+    cl::sycl::nd_item<1> ndItem) const {
+  return true;
+}
+
+template <typename lhs_t, typename rhs_1_t, typename rhs_2_t>
+PORTBLAS_INLINE typename Ger<lhs_t, rhs_1_t, rhs_2_t>::value_t
+Ger<lhs_t, rhs_1_t, rhs_2_t>::eval(cl::sycl::nd_item<1> ndItem) {
+  using index_t = typename Ger<lhs_t, rhs_1_t, rhs_2_t>::index_t;
+
+  const index_t subgroup_size = ndItem.get_sub_group().get_local_range().get(0);
+  const index_t subgroups_per_col = nRowsWG_ / subgroup_size;
+  const index_t subgroups_per_group =
+      ndItem.get_sub_group().get_group_range().get(0);
+
+  const index_t group_size = ndItem.get_local_range(0);
+
+  // col_per_workitem <= subgroup_size
+  const index_t col_per_workitem = nColsWG_ * nRowsWG_ / group_size;
+
+  const index_t group_id = ndItem.get_group(0);
+  const index_t idWFR = group_id % nWG_row_;
+  const index_t idWFC = group_id / nWG_row_;
+
+  const index_t subgroup_id = ndItem.get_sub_group().get_group_id().get(0);
+  const index_t subgroup_local_id =
+      ndItem.get_sub_group().get_local_id().get(0);
+
+  const index_t id_row0 = idWFR * nRowsWG_ +
+                          subgroup_size * (subgroup_id % subgroups_per_col) +
+                          subgroup_local_id;
+  const index_t id_col0 =
+      idWFC * nColsWG_ + col_per_workitem * (subgroup_id / subgroups_per_col);
+
+  const index_t dimR = lhs_.get_size_row();
+  const index_t dimC = lhs_.get_size_col();
+  const bool id_row_active = id_row0 < dimR;
+
+#ifndef __ADAPTIVECPP__
+  const value_t rhs_2 = (subgroup_local_id < col_per_workitem &&
+                         id_col0 + subgroup_local_id < dimC)
+                            ? rhs_2_.eval(id_col0 + subgroup_local_id)
+                            : 0;
+#endif
+
+  const value_t scal_rhs_1 = id_row_active ? scalar_ * rhs_1_.eval(id_row0) : 0;
+
+  value_t prefetch_lhs_ =
+      (id_row_active && id_col0 < dimC) ? lhs_.eval(id_row0, id_col0) : 0;
+
+  for (index_t sub_id_col = 0; sub_id_col < col_per_workitem; sub_id_col++) {
+    const value_t rhs_2_sub_id_col =
+#ifndef __ADAPTIVECPP__
+        cl::sycl::group_broadcast(ndItem.get_sub_group(), rhs_2, sub_id_col);
+#else
+        rhs_2_.eval(id_col0 + sub_id_col);
+#endif
+    if (id_row_active && id_col0 + sub_id_col < dimC) {
+      lhs_.eval(id_row0, id_col0 + sub_id_col) =
+          prefetch_lhs_ + scal_rhs_1 * rhs_2_sub_id_col;
+      prefetch_lhs_ = (id_col0 + sub_id_col + 1 < dimC)
+                          ? lhs_.eval(id_row0, id_col0 + sub_id_col + 1)
+                          : 0;
+    }
+  }
+
+  return 0;
+}
+
+template <typename lhs_t, typename rhs_1_t, typename rhs_2_t>
+template <typename sharedT>
+PORTBLAS_INLINE typename Ger<lhs_t, rhs_1_t, rhs_2_t>::value_t
+Ger<lhs_t, rhs_1_t, rhs_2_t>::eval(sharedT shrMem,
+                                   cl::sycl::nd_item<1> ndItem) {
+  using index_t = typename Ger<lhs_t, rhs_1_t, rhs_2_t>::index_t;
+
+  const index_t group_id = ndItem.get_group(0);
+  const index_t idWFR = group_id % nWG_row_;
+  const index_t idWFC = group_id / nWG_row_;
+  const index_t frs_row = idWFR * nRowsWG_;
+  const index_t group_local_id = ndItem.get_local_id(0);
+
+  // group_size % nRowsWG_ == 0
+  const index_t id_row0 = group_local_id % nRowsWG_;
+  const index_t id_row1 = frs_row + id_row0;
+
+  index_t frs_col = idWFC * nColsWG_;
+
+  const index_t dimR = lhs_.get_size_row();
+  const index_t dimC = lhs_.get_size_col();
+
+  value_t *l_rhs_1 = shrMem.localAcc.get_pointer();
+  value_t *l_rhs_2 = shrMem.localAcc.get_pointer() + nRowsWG_;
+
+  // nRowsWG_ <= group_size
+  if (group_local_id < nRowsWG_)
+    l_rhs_1[group_local_id] =
+        (frs_row + group_local_id < dimR)
+            ? scalar_ * rhs_1_.eval(frs_row + group_local_id)
+            : 0;
+
+  // nColsWG_ <= group_size
+  if (group_local_id < nColsWG_)
+    l_rhs_2[group_local_id] = (frs_col + group_local_id < dimC)
+                                  ? rhs_2_.eval(frs_col + group_local_id)
+                                  : 0;
+
+  const index_t group_size = ndItem.get_local_range(0);
+
+  // nRowsWG_ * nColsWG_ % group_size == 0
+  const index_t col_per_workitem = nRowsWG_ * nColsWG_ / group_size;
+  const index_t subgroup_col_id = group_local_id / nRowsWG_;
+
+  const index_t id_col0 = subgroup_col_id * col_per_workitem;
+  const index_t id_col1 = frs_col + id_col0;
+
+  value_t prefetch_lhs_ =
+      (id_row1 < dimR && id_col1 < dimC) ? lhs_.eval(id_row1, id_col1) : 0;
+
+  ndItem.barrier(cl::sycl::access::fence_space::local_space);
+
+  for (index_t id_col = 0; id_col < col_per_workitem; id_col++) {
+    const value_t val = l_rhs_1[id_row0] * l_rhs_2[id_col0 + id_col];
+    if (id_row1 < dimR && id_col1 + id_col < dimC) {
+      lhs_.eval(id_row1, id_col1 + id_col) = prefetch_lhs_ + val;
+      prefetch_lhs_ = (id_col1 + id_col + 1 < dimC)
+                          ? lhs_.eval(id_row1, id_col1 + id_col + 1)
+                          : 0;
+    }
+  }
+
+  return 0;
+}
+
+template <typename lhs_t, typename rhs_1_t, typename rhs_2_t>
+PORTBLAS_INLINE void Ger<lhs_t, rhs_1_t, rhs_2_t>::bind(cl::sycl::handler &h) {
+  lhs_.bind(h);
+  rhs_1_.bind(h);
+  rhs_2_.bind(h);
+}
+template <typename lhs_t, typename rhs_1_t, typename rhs_2_t>
+PORTBLAS_INLINE void
+Ger<lhs_t, rhs_1_t, rhs_2_t>::adjust_access_displacement() {
+  lhs_.adjust_access_displacement();
+  rhs_1_.adjust_access_displacement();
+  rhs_2_.adjust_access_displacement();
+}
+
 /**** GER BY ROWS M ROWS x N BLOCK USING PROPERLY THE SHARED MEMORY ****/
 // template <...>
 template <...>
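// ---------------------------------------------------------------------------
// Reference sketch (not part of the patch): the Ger node above performs the
// BLAS-2 rank-1 update A = alpha * x * y^T + A. Each work-group updates one
// nRowsWG x nColsWG tile of A, and nWGPerRow * nWGPerCol work-groups cover
// the whole matrix. The routine below is a minimal, hypothetical host-side
// reference for the same computation (plain pointers, column-major A with
// leading dimension lda, positive increments assumed); it is only meant for
// checking the kernel's output, e.g. in a unit test.
template <typename scalar_t, typename index_t>
void reference_ger(index_t m, index_t n, scalar_t alpha, const scalar_t *x,
                   index_t incx, const scalar_t *y, index_t incy, scalar_t *a,
                   index_t lda) {
  for (index_t j = 0; j < n; ++j) {
    // One column at a time: the kernel broadcasts y(j) within a subgroup
    // (or stages it in local memory) and streams down the column of A.
    const scalar_t tmp = alpha * y[j * incy];
    for (index_t i = 0; i < m; ++i) {
      a[i + j * lda] += x[i * incx] * tmp;  // A(i, j) in column-major storage
    }
  }
}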