Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed is_device_copyable error for IndexValueTuple #481

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions include/operations/blas_constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -265,4 +265,16 @@ struct constant_pair {

} // namespace blas

#define PORTBLAS_DEVICE_COPY(T1, T2) \
template <> \
struct sycl::is_device_copyable<blas::IndexValueTuple<T1, T2>> \
: std::true_type {};

PORTBLAS_DEVICE_COPY(int32_t, float)
PORTBLAS_DEVICE_COPY(int32_t, double)
PORTBLAS_DEVICE_COPY(int64_t, float)
PORTBLAS_DEVICE_COPY(int64_t, double)

#undef PORTBLAS_DEVICE_COPY
muhammad-tanvir-1211 marked this conversation as resolved.
Show resolved Hide resolved

#endif // BLAS_CONSTANTS_H
71 changes: 40 additions & 31 deletions src/interface/blas2_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -345,10 +345,10 @@ template <uint32_t subgroup_size, uint32_t subgroups, uplo_type uplo,
transpose_type trn, diag_type diag, typename sb_handle_t,
typename index_t, typename container_t0, typename container_t1,
typename increment_t>
typename sb_handle_t::event_t _trsv_impl(sb_handle_t& sb_handle, index_t _N,
container_t0 _mA, index_t _lda,
container_t1 _vx, increment_t _incx,
const typename sb_handle_t::event_t& _dependencies) {
typename sb_handle_t::event_t _trsv_impl(
sb_handle_t& sb_handle, index_t _N, container_t0 _mA, index_t _lda,
container_t1 _vx, increment_t _incx,
const typename sb_handle_t::event_t& _dependencies) {
#if (SYCL_LANGUAGE_VERSION < 202000) || (defined __HIPSYCL__)
throw std::runtime_error("trsv requires SYCL 2020");
#else
Expand Down Expand Up @@ -392,7 +392,8 @@ typename sb_handle_t::event_t _trsv_impl(sb_handle_t& sb_handle, index_t _N,
auto ret = sb_handle.execute(
trsv, static_cast<index_t>(sub_num * subgroup_size),
roundUp<index_t>(sub_num * _N, sub_num * subgroup_size),
static_cast<index_t>(subgroup_size * (subgroup_size + 2 + sub_num)), _dependencies);
static_cast<index_t>(subgroup_size * (subgroup_size + 2 + sub_num)),
_dependencies);

blas::helper::enqueue_deallocate(ret, sync_buffer, queue);

Expand Down Expand Up @@ -727,18 +728,16 @@ template <uint32_t subgroup_size, uint32_t subgroups, uplo_type uplo,
transpose_type trn, diag_type diag, typename sb_handle_t,
typename index_t, typename container_t0, typename container_t1,
typename increment_t>
typename sb_handle_t::event_t _tbsv_impl(sb_handle_t& sb_handle, index_t _N,
index_t _K, container_t0 _mA,
index_t _lda, container_t1 _vx,
increment_t _incx,
const typename sb_handle_t::event_t& _dependencies) {
typename sb_handle_t::event_t _tbsv_impl(
sb_handle_t& sb_handle, index_t _N, index_t _K, container_t0 _mA,
index_t _lda, container_t1 _vx, increment_t _incx,
const typename sb_handle_t::event_t& _dependencies) {
#if (SYCL_LANGUAGE_VERSION < 202000) || (defined __HIPSYCL__)
throw std::runtime_error("tbsv requires SYCL 2020");
#else
static_assert(subgroup_size % subgroups == 0,
"`subgroups` needs to be a multiple of `subgroup_size`.");


if (_K >= _N) throw std::invalid_argument("Erroneous parameter: _K >= _N");

using one = constant<increment_t, const_val::one>;
Expand Down Expand Up @@ -780,7 +779,8 @@ typename sb_handle_t::event_t _tbsv_impl(sb_handle_t& sb_handle, index_t _N,
auto ret = sb_handle.execute(
tbsv, static_cast<index_t>(sub_num * subgroup_size),
roundUp<index_t>(sub_num * _N, sub_num * subgroup_size),
static_cast<index_t>(subgroup_size * (subgroup_size + 2 + sub_num)), _dependencies);
static_cast<index_t>(subgroup_size * (subgroup_size + 2 + sub_num)),
_dependencies);

blas::helper::enqueue_deallocate(ret, sync_buffer, queue);

Expand All @@ -792,10 +792,9 @@ template <uint32_t subgroup_size, uint32_t subgroups, uplo_type uplo,
transpose_type trn, diag_type diag, typename sb_handle_t,
typename index_t, typename container_t0, typename container_t1,
typename increment_t>
typename sb_handle_t::event_t _tpsv_impl(sb_handle_t& sb_handle, index_t _N,
container_t0 _mA, container_t1 _vx,
increment_t _incx,
const typename sb_handle_t::event_t& _dependencies) {
typename sb_handle_t::event_t _tpsv_impl(
sb_handle_t& sb_handle, index_t _N, container_t0 _mA, container_t1 _vx,
increment_t _incx, const typename sb_handle_t::event_t& _dependencies) {
#if (SYCL_LANGUAGE_VERSION < 202000) || (defined __HIPSYCL__)
throw std::runtime_error("tpsv requires SYCL 2020");
#else
Expand Down Expand Up @@ -823,8 +822,18 @@ typename sb_handle_t::event_t _tpsv_impl(sb_handle_t& sb_handle, index_t _N,
: ((roundUp<index_t>(_N, subgroup_size) / subgroup_size) - 1);
sync_vec[1] = sync_vec[0];

auto sync_buffer =
blas::make_sycl_iterator_buffer<int32_t>(sync_vec, sync_vec.size());
constexpr bool is_usm = std::is_pointer<container_t0>::value;
auto queue = sb_handle.get_queue();

auto sync_buffer = blas::helper::allocate < is_usm
? blas::helper::AllocType::usm
: blas::helper::AllocType::buffer,
int32_t > (sync_vec.size(), queue);

auto copy_sync = blas::helper::copy_to_device<int32_t>(
queue, sync_vec.data(), sync_buffer, sync_vec.size());
sb_handle.wait(copy_sync);

auto sync =
make_vector_view(sync_buffer, one_increment_t::value(), sync_vec.size());

Expand All @@ -833,11 +842,13 @@ typename sb_handle_t::event_t _tpsv_impl(sb_handle_t& sb_handle, index_t _N,
vx, mA, sync);

const index_t sub_num = subgroups;
return sb_handle.execute(
auto ret = sb_handle.execute(
tpsv, static_cast<index_t>(sub_num * subgroup_size),
roundUp<index_t>(sub_num * _N, sub_num * subgroup_size),
static_cast<index_t>(subgroup_size * (subgroup_size + 2 + sub_num)),
_dependencies);
blas::helper::enqueue_deallocate(ret, sync_buffer, queue);
return ret;
#endif
}

Expand Down Expand Up @@ -1329,11 +1340,10 @@ typename sb_handle_t::event_t inline _spr2(
template <typename sb_handle_t, typename index_t, typename element_t,
typename container_t0, typename increment_t, typename container_t1,
typename container_t2>
typename sb_handle_t::event_t inline _syr2(sb_handle_t& sb_handle, char _Uplo,
index_t _N, element_t _alpha,
container_t0 _vx, increment_t _incx,
container_t1 _vy, increment_t _incy,
container_t2 _mA, index_t _lda,
typename sb_handle_t::event_t inline _syr2(
sb_handle_t& sb_handle, char _Uplo, index_t _N, element_t _alpha,
container_t0 _vx, increment_t _incx, container_t1 _vy, increment_t _incy,
container_t2 _mA, index_t _lda,
const typename sb_handle_t::event_t& _dependencies) {
// TODO: Here we can use some heuristics to select localn global, local, and
// scratch size per device
Expand Down Expand Up @@ -1366,17 +1376,16 @@ typename sb_handle_t::event_t _tpmv(
sb_handle_t& sb_handle, char _Uplo, char _trans, char _Diag, index_t _N,
container_t0 _mA, container_t1 _vx, increment_t _incx,
const typename sb_handle_t::event_t& _dependencies) {
INST_UPLO_TRANS_DIAG(blas::tpmv::backend::_tpmv, sb_handle, _N, _mA, _vx,
_incx, _dependencies)
INST_UPLO_TRANS_DIAG(blas::tpmv::backend::_tpmv, sb_handle, _N, _mA, _vx,
_incx, _dependencies)
}

template <typename sb_handle_t, typename index_t, typename container_t0,
typename container_t1, typename increment_t>
typename sb_handle_t::event_t _tpsv(sb_handle_t& sb_handle, char _Uplo,
char _trans, char _Diag, index_t _N,
container_t0 _mA, container_t1 _vx,
increment_t _incx,
const typename sb_handle_t::event_t& _dependencies) {
typename sb_handle_t::event_t _tpsv(
sb_handle_t& sb_handle, char _Uplo, char _trans, char _Diag, index_t _N,
container_t0 _mA, container_t1 _vx, increment_t _incx,
const typename sb_handle_t::event_t& _dependencies) {
INST_UPLO_TRANS_DIAG(blas::tpsv::backend::_tpsv, sb_handle, _N, _mA, _vx,
_incx, _dependencies)
}
Expand Down
Empty file removed src/operations/blas2/trsv.hpp
Empty file.