Skip to content

Commit

Permalink
First dips at moving execution space selection to ttg::device::Task
Browse files Browse the repository at this point in the history
Signed-off-by: Joseph Schuchart <[email protected]>
  • Loading branch information
devreal committed Oct 28, 2024
1 parent 7fb4027 commit f9971f5
Show file tree
Hide file tree
Showing 13 changed files with 648 additions and 374 deletions.
20 changes: 18 additions & 2 deletions examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ if (TARGET tiledarray)
COMPILE_DEFINITIONS BLOCK_SPARSE_GEMM=1;BTAS_TARGET_MAX_INDEX_RANK=2)

add_ttg_executable(testing_dpotrf potrf/testing_dpotrf.cc LINK_LIBRARIES tiledarray lapackpp)
add_ttg_executable(testing_dpotrf_host potrf/testing_dpotrf.cc
LINK_LIBRARIES tiledarray lapackpp
COMPILE_DEFINITIONS TTG_ENABLE_DEV_HOST=1)
add_ttg_executable(testing_dtrtri potrf/testing_dtrtri.cc LINK_LIBRARIES tiledarray lapackpp)
add_ttg_executable(testing_dlauum potrf/testing_dlauum.cc LINK_LIBRARIES tiledarray lapackpp)
add_ttg_executable(testing_dpoinv potrf/testing_dpoinv.cc LINK_LIBRARIES tiledarray lapackpp)
Expand Down Expand Up @@ -50,14 +53,27 @@ if (TARGET tiledarray)
endif()

if (TTG_HAVE_CUDA)
add_ttg_executable(chain-ttg-cuda task-benchmarks/chain-ttg-dev.cc LINK_LIBRARIES tiledarray RUNTIMES "parsec")
add_ttg_executable(chain-ttg-dev-cuda task-benchmarks/chain-ttg-dev.cc
COMPILE_DEFINITIONS CHAIN_CUDA=1
LINK_LIBRARIES tiledarray
RUNTIMES "parsec")
endif(TTG_HAVE_CUDA)

if (TTG_HAVE_HIP)
add_ttg_executable(chain-ttg-hip task-benchmarks/chain-ttg-dev.cc LINK_LIBRARIES tiledarray RUNTIMES "parsec")
add_ttg_executable(chain-ttg-dev-hip task-benchmarks/chain-ttg-dev.cc
COMPILE_DEFINITIONS CHAIN_HIP=1
LINK_LIBRARIES tiledarray
RUNTIMES "parsec")
endif(TTG_HAVE_HIP)
endif()

add_ttg_executable(chain-ttg-host task-benchmarks/chain-ttg.cc)

add_ttg_executable(chain-ttg-dev-host task-benchmarks/chain-ttg-dev.cc
COMPILE_DEFINITIONS CHAIN_HOST=1
LINK_LIBRARIES tiledarray
RUNTIMES "parsec")

if (TARGET MADworld)
add_ttg_executable(madness-1d madness/madness-1d/madness-1d.cc RUNTIMES "mad")
if (TARGET blaspp) #(CBLAS_FOUND AND MKL_FOUND)
Expand Down
107 changes: 62 additions & 45 deletions examples/potrf/potrf.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,19 @@
#include "util.h"
#include "../devblas_helper.h"

#if (defined(TTG_ENABLE_CUDA) || defined(TTG_ENABLE_HIP))
#if (defined(TTG_ENABLE_CUDA) || defined(TTG_ENABLE_HIP) || defined(TTG_ENABLE_DEV_HOST))
#define ENABLE_DEVICE_KERNEL 1
#endif

#if defined(TTG_HAVE_CUDART)
#define ES ttg::ExecutionSpace::CUDA
#define TASKRET -> ttg::device::Task
#include <cusolverDn.h>
#elif defined(TTG_ENABLE_HIP)
#define ES ttg::ExecutionSpace::HIP
#define TASKRET -> ttg::device::Task
#include <hipsolver/hipsolver.h>
#include <hipblas/hipblas.h>
#else
#define ES ttg::ExecutionSpace::Host
#define TASKRET -> void
#endif

namespace potrf {
Expand All @@ -35,21 +32,21 @@ namespace potrf {
#if defined(ENABLE_DEVICE_KERNEL)
/**
 * Query the device solver library for the workspace size of a
 * lower-triangular potrf on tile A.
 *
 * The matrix pointer passed to the query is nullptr; only A.cols() and
 * A.lda() are handed to the solver.
 *
 * @param A tile providing the dimension (A.cols()) and leading dimension (A.lda())
 * @return the size written to Lwork by cuSOLVER (TTG_ENABLE_CUDA) or
 *         hipSOLVER (TTG_ENABLE_HIP); 0 when neither is enabled.
 */
static int device_potrf_workspace_size(MatrixTile<double> &A) {
int Lwork;
#if defined(TTG_ENABLE_CUDA)
#if defined(TTG_ENABLE_CUDA)
// NOTE(review): the status returned by the query is discarded, so Lwork
// presumably stays uninitialized if the call fails -- confirm.
cusolverDnDpotrf_bufferSize(cusolver_handle(),
CUBLAS_FILL_MODE_LOWER, A.cols(),
nullptr, A.lda(),
&Lwork);
return Lwork;
#elif defined(TTG_ENABLE_HIP)
#elif defined(TTG_ENABLE_HIP)
// NOTE(review): same as the CUDA branch, the returned status is not checked.
hipsolverDnDpotrf_bufferSize(hipsolver_handle(),
HIPSOLVER_FILL_MODE_LOWER, A.cols(),
nullptr, A.lda(),
&Lwork);
return Lwork;
#else
#else
// No device solver selected: no workspace is requested.
return 0;
#endif
#endif
}

static void device_potrf(MatrixTile<double> &A, double *workspace, int Lwork, int *devInfo) {
Expand All @@ -64,13 +61,16 @@ namespace potrf {
A.buffer().current_device_ptr(), A.lda(),
workspace, Lwork,
devInfo);
#elif defined(TTG_ENABLE_HIP)
#elif defined(TTG_ENABLE_HIP)
hipsolverDpotrf(hipsolver_handle(),
HIPSOLVER_FILL_MODE_LOWER, A.cols(),
A.buffer().current_device_ptr(), A.lda(),
workspace, Lwork,
devInfo);
#endif
#else
auto info = lapack::potrf(lapack::Uplo::Lower, A.rows(), A.buffer().current_device_ptr(), A.lda());
assert(info == 0);
#endif
}

static void device_norm(const MatrixTile<double> &A, double *norm) {
Expand All @@ -81,9 +81,11 @@ namespace potrf {
auto handle = cublas_handle();
//double n = 1.0;
cublasDnrm2(handle, size, buffer, 1, norm);
#elif defined(TTG_ENABLE_HIP)
#elif defined(TTG_ENABLE_HIP)
hipblasDnrm2(hipblas_handle(), size, buffer, 1, norm);
#endif
#else
*norm = blas::nrm2(size, buffer, 1);
#endif
}
#endif // ENABLE_DEVICE_KERNEL

Expand All @@ -99,7 +101,8 @@ namespace potrf {
//std::cout << "Creating CUDA POTRF task " << std::endl;
auto f_dev = [=, iallocator = std::move(iallocator)]
(const Key1& key, MatrixTile<T>&& tile_kk,
std::tuple<ttg::Out<Key2, MatrixTile<T>>, ttg::Out<Key2, MatrixTile<T>>>& out) TASKRET {
std::tuple<ttg::Out<Key2, MatrixTile<T>>, ttg::Out<Key2, MatrixTile<T>>>& out)
-> ttg::device::Task<ES> {
const auto K = key[0];

/* compute successors before submitting the kernel
Expand Down Expand Up @@ -186,7 +189,7 @@ namespace potrf {
ttg::abort();
}
};
return ttg::make_tt<ES>(f_dev, ttg::edges(ttg::fuse(input, input_disp)), ttg::edges(output_result, output_trsm), "POTRF",
return ttg::make_tt(f_dev, ttg::edges(ttg::fuse(input, input_disp)), ttg::edges(output_result, output_trsm), "POTRF",
{"tile_kk/dispatcher"}, {"output_result", "output_trsm"});
#else /* defined(ENABLE_DEVICE_KERNEL) */
auto f = [=](const Key1& key, MatrixTile<T>&& tile_kk,
Expand Down Expand Up @@ -234,7 +237,7 @@ namespace potrf {
#if defined(ENABLE_DEVICE_KERNEL)
auto f = [=](const Key2& key, const MatrixTile<T>& tile_kk, MatrixTile<T>&& tile_mk,
std::tuple<ttg::Out<Key2, MatrixTile<T>>, ttg::Out<Key2, MatrixTile<T>>, ttg::Out<Key3, MatrixTile<T>>,
ttg::Out<Key3, MatrixTile<T>>>& out) TASKRET {
ttg::Out<Key3, MatrixTile<T>>>& out) -> ttg::device::Task<ES> {
const int M = key[0];
const int K = key[1]; // the column equals the outermost loop K (same as PO)

Expand Down Expand Up @@ -302,6 +305,9 @@ namespace potrf {
mb, nb, &alpha,
tile_kk.buffer().current_device_ptr(), tile_kk.lda(),
tile_mk.buffer().current_device_ptr(), tile_mk.lda());
#else
blas::trsm(blas::Layout::ColMajor, blas::Side::Right, lapack::Uplo::Lower, blas::Op::Trans, blas::Diag::NonUnit,
mb, nb, 1.0, tile_kk.data(), tile_kk.lda(), tile_mk.data(), tile_mk.lda());

#endif

Expand All @@ -320,7 +326,7 @@ namespace potrf {
co_await ttg::device::forward(ttg::device::broadcast<0, 1, 2, 3>(std::make_tuple(key, Key2(K, M), keylist_row, keylist_col),
std::move(tile_mk), out));
};
return ttg::make_tt<ES>(f, ttg::edges(input_kk, ttg::fuse(input_mk, input_disp)),
return ttg::make_tt(f, ttg::edges(input_kk, ttg::fuse(input_mk, input_disp)),
ttg::edges(output_result, output_diag, output_row, output_col), "TRSM",
{"tile_kk", "tile_mk/dispatcher"}, {"output_result", "tile_mk", "output_row", "output_col"});
#else // defined(ENABLE_DEVICE_KERNEL)
Expand Down Expand Up @@ -386,8 +392,8 @@ namespace potrf {
ttg::Edge<Key2, MatrixTile<typename MatrixT::element_type>>& output_syrk) {
using T = typename MatrixT::element_type;
#if defined(ENABLE_DEVICE_KERNEL)
auto f = [=](const Key2& key, const MatrixTile<T>& tile_mk, MatrixTile<T>&& tile_kk,
std::tuple<ttg::Out<Key1, MatrixTile<T>>, ttg::Out<Key2, MatrixTile<T>>>& out) TASKRET {
auto f = [=](const Key2& key, const MatrixTile<T>& tile_mk, MatrixTile<T>&& tile_kk)
-> ttg::device::Task<ES> {
const int K = key[0];
const int M = key[1];

Expand Down Expand Up @@ -432,6 +438,9 @@ namespace potrf {
mb, nb, &alpha,
tile_mk.buffer().current_device_ptr(), tile_mk.lda(), &beta,
tile_kk.buffer().current_device_ptr(), tile_kk.lda());
#else
blas::syrk(blas::Layout::ColMajor, lapack::Uplo::Lower, blas::Op::NoTrans, mb, nb, -1.0, tile_mk.data(),
tile_mk.lda(), 1.0, tile_kk.data(), tile_kk.lda());
#endif

#ifdef DEBUG_TILES_VALUES
Expand All @@ -449,18 +458,17 @@ namespace potrf {
if (M == K + 1) {
/* send the tile to potrf */
if (ttg::tracing()) ttg::print("SYRK(", key, "): sending output to POTRF(", Key1{K + 1}, ")");
co_await ttg::device::send<0>(Key1(K + 1), std::move(tile_kk), out);
co_await ttg::device::send<0>(Key1(K + 1), std::move(tile_kk));
} else {
/* send output to next syrk */
if (ttg::tracing()) ttg::print("SYRK(", key, "): sending output to SYRK(", Key2{K + 1, M}, ")");
co_await ttg::device::send<1>(Key2(K + 1, M), std::move(tile_kk), out);
co_await ttg::device::send<1>(Key2(K + 1, M), std::move(tile_kk));
}
};
return ttg::make_tt<ES>(f, ttg::edges(input_mk, ttg::fuse(input_kk, input_disp)), ttg::edges(output_potrf, output_syrk),
return ttg::make_tt(f, ttg::edges(input_mk, ttg::fuse(input_kk, input_disp)), ttg::edges(output_potrf, output_syrk),
"SYRK", {"tile_mk", "tile_kk/dispatcher"}, {"output_potrf", "output_syrk"});
#else // defined(ENABLE_DEVICE_KERNEL)
auto f = [=](const Key2& key, const MatrixTile<T>& tile_mk, MatrixTile<T>&& tile_kk,
std::tuple<ttg::Out<Key1, MatrixTile<T>>, ttg::Out<Key2, MatrixTile<T>>>& out) {
auto f = [=](const Key2& key, const MatrixTile<T>& tile_mk, MatrixTile<T>&& tile_kk) {
const int K = key[0];
const int M = key[1];

Expand All @@ -487,11 +495,11 @@ namespace potrf {
if (M == K + 1) {
/* send the tile to potrf */
if (ttg::tracing()) ttg::print("SYRK(", key, "): sending output to POTRF(", Key1{K + 1}, ")");
ttg::send<0>(Key1(K + 1), std::move(tile_kk), out);
ttg::send<0>(Key1(K + 1), std::move(tile_kk));
} else {
/* send output to next syrk */
if (ttg::tracing()) ttg::print("SYRK(", key, "): sending output to SYRK(", Key2{K + 1, M}, ")");
ttg::send<1>(Key2(K + 1, M), std::move(tile_kk), out);
ttg::send<1>(Key2(K + 1, M), std::move(tile_kk));
}
};
return ttg::make_tt(f, ttg::edges(input_mk, ttg::fuse(input_kk, input_disp)), ttg::edges(output_potrf, output_syrk),
Expand All @@ -509,8 +517,8 @@ namespace potrf {
ttg::Edge<Key3, MatrixTile<typename MatrixT::element_type>>& output_gemm) {
using T = typename MatrixT::element_type;
#if defined(ENABLE_DEVICE_KERNEL)
auto f = [=](const Key3& key, const MatrixTile<T>& tile_mk, const MatrixTile<T>& tile_nk, MatrixTile<T>&& tile_mn,
std::tuple<ttg::Out<Key2, MatrixTile<T>>, ttg::Out<Key3, MatrixTile<T>>>& out) TASKRET {
auto f = [=](const Key3& key, const MatrixTile<T>& tile_mk, const MatrixTile<T>& tile_nk, MatrixTile<T>&& tile_mn)
-> ttg::device::Task<ES> {
const int M = key[0];
const int N = key[1];
const int K = key[2];
Expand Down Expand Up @@ -559,6 +567,10 @@ namespace potrf {
tile_mk.buffer().current_device_ptr(), tile_mk.lda(),
tile_nk.buffer().current_device_ptr(), tile_nk.lda(), &beta,
tile_mn.buffer().current_device_ptr(), tile_mn.lda());
#else
blas::gemm(blas::Layout::ColMajor, blas::Op::NoTrans, blas::Op::Trans, tile_mk.rows(), tile_nk.rows(),
tile_nk.cols(), -1.0, tile_mk.data(), tile_mk.lda(), tile_nk.data(), tile_nk.lda(), 1.0,
tile_mn.data(), tile_mn.lda());
#endif


Expand All @@ -578,19 +590,18 @@ namespace potrf {
if (N == K + 1) {
/* send the tile to trsm */
if (ttg::tracing()) ttg::print("GEMM(", key, "): sending output to TRSM(", Key2{M, N}, ")");
co_await ttg::device::send<0>(Key2(M, N), std::move(tile_mn), out);
co_await ttg::device::send<0>(Key2(M, N), std::move(tile_mn));
} else {
/* send the tile to the next gemm */
if (ttg::tracing()) ttg::print("GEMM(", key, "): sending output to GEMM(", Key3{M, N, K + 1}, ")");
co_await ttg::device::send<1>(Key3(M, N, K + 1), std::move(tile_mn), out);
co_await ttg::device::send<1>(Key3(M, N, K + 1), std::move(tile_mn));
}
};
return ttg::make_tt<ES>(f, ttg::edges(input_mk, input_nk, ttg::fuse(input_disp, input_mn)),
return ttg::make_tt(f, ttg::edges(input_mk, input_nk, ttg::fuse(input_disp, input_mn)),
ttg::edges(output_trsm, output_gemm), "GEMM", {"input_mk", "input_kn", "input_mn/dispatcher"},
{"output_trsm", "outout_gemm"});
#else // defined(ENABLE_DEVICE_KERNEL)
auto f = [=](const Key3& key, const MatrixTile<T>& tile_mk, const MatrixTile<T>& tile_nk, MatrixTile<T>&& tile_mn,
std::tuple<ttg::Out<Key2, MatrixTile<T>>, ttg::Out<Key3, MatrixTile<T>>>& out) {
auto f = [=](const Key3& key, const MatrixTile<T>& tile_mk, const MatrixTile<T>& tile_nk, MatrixTile<T>&& tile_mn) {
const int M = key[0];
const int N = key[1];
const int K = key[2];
Expand All @@ -617,11 +628,11 @@ namespace potrf {
if (N == K + 1) {
/* send the tile to trsm */
if (ttg::tracing()) ttg::print("GEMM(", key, "): sending output to TRSM(", Key2{M, N}, ")");
ttg::send<0>(Key2(M, N), std::move(tile_mn), out);
ttg::send<0>(Key2(M, N), std::move(tile_mn));
} else {
/* send the tile to the next gemm */
if (ttg::tracing()) ttg::print("GEMM(", key, "): sending output to GEMM(", Key3{M, N, K + 1}, ")");
ttg::send<1>(Key3(M, N, K + 1), std::move(tile_mn), out);
ttg::send<1>(Key3(M, N, K + 1), std::move(tile_mn));
}
};
return ttg::make_tt(f, ttg::edges(input_mk, input_nk, ttg::fuse(input_disp, input_mn)),
Expand All @@ -634,33 +645,31 @@ namespace potrf {
auto make_dispatcher(ttg::Edge<Key2, MatrixTile<T>>& input, ttg::Edge<Key1, MatrixTile<T>>& to_potrf,
ttg::Edge<Key2, MatrixTile<T>>& to_trsm, ttg::Edge<Key2, MatrixTile<T>>& to_syrk,
ttg::Edge<Key3, MatrixTile<T>>& to_gemm) {
auto f = [=](const Key2& key, const MatrixTile<T>& tile,
std::tuple<ttg::Out<Key1, MatrixTile<T>>, ttg::Out<Key2, MatrixTile<T>>, ttg::Out<Key2, MatrixTile<T>>,
ttg::Out<Key3, MatrixTile<T>>>& out) {
auto f = [=](const Key2& key, const MatrixTile<T>& tile) {
if (ttg::tracing()) ttg::print("POTRF_Dispatch(", key, ")");
if (0 == key[0] && 0 == key[1]) {
// First element goes to POTRF
if (ttg::tracing()) ttg::print("POTRF_Dispatch(", key, ") sending to POTRF(", Key1{key[0]}, ")");
ttg::send<0>(Key1{key[0]}, tile, out);
ttg::send<0>(Key1{key[0]}, tile);
return;
}
if (key[0] == key[1]) {
// Other diagonal elements go to SYRK
if (ttg::tracing()) ttg::print("POTRF_Dispatch(", key, ") sending to SYRK(", Key2{0, key[0]}, ")");
ttg::send<2>(Key2{0, key[0]}, tile, out);
ttg::send<2>(Key2{0, key[0]}, tile);
return;
}
// We only consider the lower triangular
assert(key[0] > key[1]);
if (0 == key[1]) {
// First column goes to TRSM
if (ttg::tracing()) ttg::print("POTRF_Dispatch(", key, ") sending to TRSM(", key, ")");
ttg::send<1>(key, tile, out);
ttg::send<1>(key, tile);
return;
}
// Rest goes to GEMM
if (ttg::tracing()) ttg::print("POTRF_Dispatch(", key, ") sending to GEMM(", Key3{key[0], key[1], 0}, ")");
ttg::send<3>(Key3{key[0], key[1], 0}, tile, out);
ttg::send<3>(Key3{key[0], key[1], 0}, tile);
};

return ttg::make_tt(f, ttg::edges(input), ttg::edges(to_potrf, to_trsm, to_syrk, to_gemm), "POTRF Dispatch",
Expand Down Expand Up @@ -705,28 +714,36 @@ namespace potrf {
tt_potrf->set_keymap(keymap1);
tt_potrf->set_defer_writer(defer_write);
#ifdef ENABLE_DEVICE_KERNEL
tt_potrf->set_devicemap(devmap1);
if constexpr (ES != ttg::ExecutionSpace::Host) {
tt_potrf->set_devicemap(devmap1);
}
#endif // ENABLE_DEVICE_KERNEL

auto tt_trsm = make_trsm(A, disp_trsm, potrf_trsm, gemm_trsm, trsm_syrk, trsm_gemm_row, trsm_gemm_col, output);
tt_trsm->set_keymap(keymap2a);
tt_trsm->set_defer_writer(defer_write);
#ifdef ENABLE_DEVICE_KERNEL
tt_trsm->set_devicemap(devmap2a);
if constexpr (ES != ttg::ExecutionSpace::Host) {
tt_trsm->set_devicemap(devmap2a);
}
#endif // ENABLE_DEVICE_KERNEL

auto tt_syrk = make_syrk(A, disp_syrk, trsm_syrk, syrk_syrk, syrk_potrf, syrk_syrk);
tt_syrk->set_keymap(keymap2b);
tt_syrk->set_defer_writer(defer_write);
#ifdef ENABLE_DEVICE_KERNEL
tt_syrk->set_devicemap(devmap2b);
if constexpr (ES != ttg::ExecutionSpace::Host) {
tt_syrk->set_devicemap(devmap2b);
}
#endif // ENABLE_DEVICE_KERNEL

auto tt_gemm = make_gemm(A, disp_gemm, trsm_gemm_row, trsm_gemm_col, gemm_gemm, gemm_trsm, gemm_gemm);
tt_gemm->set_keymap(keymap3);
tt_gemm->set_defer_writer(defer_write);
#ifdef ENABLE_DEVICE_KERNEL
tt_gemm->set_devicemap(devmap3);
if constexpr (ES != ttg::ExecutionSpace::Host) {
tt_gemm->set_devicemap(devmap3);
}
#endif // ENABLE_DEVICE_KERNEL

/* Priorities taken from DPLASMA */
Expand Down
Loading

0 comments on commit f9971f5

Please sign in to comment.