Skip to content

Commit

Permalink
Put all device functions into ttg::device namespace
Browse files Browse the repository at this point in the history
Except for ttg::Buffer, which is separate and may be used by non-device
tasks/data structures.

The following were renamed:
- to_device -> select
- wait_kernel -> wait

Signed-off-by: Joseph Schuchart <[email protected]>
  • Loading branch information
devreal committed Dec 18, 2023
1 parent 9afe0c0 commit f424bcd
Show file tree
Hide file tree
Showing 5 changed files with 443 additions and 507 deletions.
30 changes: 15 additions & 15 deletions examples/potrf/potrf.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@

#if defined(TTG_HAVE_CUDART)
#define ES ttg::ExecutionSpace::CUDA
#define TASKRET -> ttg::device_task
#define TASKRET -> ttg::device::Task
#include <cusolverDn.h>
#elif defined(TTG_HAVE_HIP)
#define ES ttg::ExecutionSpace::HIP
#define TASKRET -> ttg::device_task
#define TASKRET -> ttg::device::Task
#include <hipsolver/hipsolver.h>
#include <hipblas/hipblas.h>
#else
Expand Down Expand Up @@ -134,14 +134,14 @@ namespace potrf {
//auto norms_s = ttg::make_scratch(norms.data(), ttg::scope::Allocate, norms.size());
/* the workspace and the devInfo must be device-level pointers */
//co_await ttg::to_device(tile_kk.buffer(), devWS, devInfo, norms_s);
co_await ttg::to_device(tile_kk.buffer(), devWS, devInfo);
co_await ttg::device::select(tile_kk.buffer(), devWS, devInfo);

/* compute the norm at input */
static_assert(std::is_same_v<double, T>, "Norm debugging only implementation for T=double");
device_norm(tile_kk, &norms[0]);
#else
/* the workspace and the devInfo must be device-level pointers */
co_await ttg::to_device(tile_kk.buffer(), devWS, devInfo);
co_await ttg::device::select(tile_kk.buffer(), devWS, devInfo);
#endif // DEBUG_TILES_VALUES

int device = ttg::device::current_device();
Expand All @@ -159,14 +159,14 @@ namespace potrf {
static_assert(std::is_same_v<double, T>, "Verification only implementation for T=double");
device_norm(tile_kk, &norms[1]);
/* wait for the kernel to complete */
co_await ttg::wait_kernel(devInfo);
co_await ttg::device::wait(devInfo);
// check that we got the input tile we expected
assert(check_norm(tile_kk.norm(), norms[0]));
// set the new norm
tile_kk.set_norm(norms[1]);
#else
/* wait for the kernel to complete */
co_await ttg::wait_kernel(devInfo);
co_await ttg::device::wait(devInfo);
#endif // DEBUG_TILES_VALUES

delete[] hostWS;
Expand Down Expand Up @@ -268,9 +268,9 @@ namespace potrf {
#ifdef DEBUG_TILES_VALUES
std::array<T, 3> norms; // input for tile_kk & tile_mk and output
//auto norms_s = ttg::make_scratch(norms.data(), ttg::scope::Allocate, norms.size());
co_await ttg::to_device(tile_kk.buffer(), tile_mk.buffer());
co_await ttg::device::select(tile_kk.buffer(), tile_mk.buffer());
#else
co_await ttg::to_device(tile_kk.buffer(), tile_mk.buffer());
co_await ttg::device::select(tile_kk.buffer(), tile_mk.buffer());
#endif // DEBUG_TILES_VALUES

int device = ttg::device::current_device();
Expand Down Expand Up @@ -306,7 +306,7 @@ namespace potrf {
/* compute the norms at input */
device_norm(tile_mk, &norms[2]);
/* wait for the kernel to complete */
co_await ttg::wait_kernel();
co_await ttg::device::wait();
// check that we got the input tiles we expected
assert(check_norm(tile_kk.norm(), norms[0]));
assert(check_norm(tile_mk.norm(), norms[1]));
Expand Down Expand Up @@ -400,12 +400,12 @@ namespace potrf {
#ifdef DEBUG_TILES_VALUES
std::array<T, 3> norms; // input for tile_kk & tile_mk and output
//auto norms_s = ttg::make_scratch(norms.data(), ttg::scope::Allocate, norms.size());
co_await ttg::to_device(tile_kk.buffer(), tile_mk.buffer());
co_await ttg::device::select(tile_kk.buffer(), tile_mk.buffer());
/* compute the norms at input */
device_norm(tile_mk, &norms[0]);
device_norm(tile_kk, &norms[1]);
#else
co_await ttg::to_device(tile_kk.buffer(), tile_mk.buffer());
co_await ttg::device::select(tile_kk.buffer(), tile_mk.buffer());
#endif // DEBUG_TILES_VALUES

int device = ttg::device::current_device();
Expand Down Expand Up @@ -435,7 +435,7 @@ namespace potrf {
/* compute the norm at output */
device_norm(tile_kk, &norms[2]);
/* wait for the kernel to complete */
co_await ttg::wait_kernel();
co_await ttg::device::wait();
// check that we got the input tiles we expected
assert(check_norm(tile_mk.norm(), norms[0]));
assert(check_norm(tile_kk.norm(), norms[1]));
Expand Down Expand Up @@ -526,14 +526,14 @@ namespace potrf {
#ifdef DEBUG_TILES_VALUES
std::array<T, 4> norms; // input for tile_mk & tile_nk & tile_mn and output
//auto norms_s = ttg::make_scratch(norms.data(), ttg::scope::Allocate, norms.size());
co_await ttg::to_device(tile_mk.buffer(), tile_nk.buffer(), tile_mn.buffer());
co_await ttg::device::select(tile_mk.buffer(), tile_nk.buffer(), tile_mn.buffer());

/* compute the norms at input */
device_norm(tile_mk, &norms[0]);
device_norm(tile_nk, &norms[1]);
device_norm(tile_mn, &norms[2]);
#else
co_await ttg::to_device(tile_mk.buffer(), tile_nk.buffer(), tile_mn.buffer());
co_await ttg::device::select(tile_mk.buffer(), tile_nk.buffer(), tile_mn.buffer());
#endif // DEBUG_TILES_VALUES

int device = ttg::device::current_device();
Expand Down Expand Up @@ -563,7 +563,7 @@ namespace potrf {
/* compute the norm at output */
device_norm(tile_mn, &norms[3]);
/* wait for the kernel to complete */
co_await ttg::wait_kernel();
co_await ttg::device::wait();
// check that we got the input tiles we expected
assert(check_norm(tile_mk.norm(), norms[0]));
assert(check_norm(tile_nk.norm(), norms[1]));
Expand Down
6 changes: 3 additions & 3 deletions examples/spmm/spmm_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -812,7 +812,7 @@ class SpMM25D {
}
}

ttg::device_task op(const Key<3> &ijk, typename baseT::input_refs_tuple_type &&_ijk,
ttg::device::Task op(const Key<3> &ijk, typename baseT::input_refs_tuple_type &&_ijk,
std::tuple<Out<Key<2>, Blk>, Out<Key<3>, Blk>> &result) {
const auto i = ijk[0];
const auto j = ijk[1];
Expand All @@ -830,7 +830,7 @@ class SpMM25D {
}

/* pull all buffers onto the device */
co_await ttg::to_device(A.b, B.b, C.b);
co_await ttg::device::select(A.b, B.b, C.b);

/* everything is on the device, call the gemm */
device_gemm(C, A, B);
Expand All @@ -844,7 +844,7 @@ class SpMM25D {
(have_next_k ? std::to_string(next_k) : "does not exist"));

/* wait for the kernel to complete */
co_await ttg::wait_kernel();
co_await ttg::device::wait();


// compute the contrib, pass the running total to the next flow, if needed
Expand Down
Loading

0 comments on commit f424bcd

Please sign in to comment.