Skip to content

Commit

Permalink
Fixes to device hint implementation
Browse files Browse the repository at this point in the history
Signed-off-by: Joseph Schuchart <[email protected]>
  • Loading branch information
devreal committed May 15, 2024
1 parent 4f86f7b commit 8164c44
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 13 deletions.
16 changes: 8 additions & 8 deletions examples/potrf/potrf.h
Original file line number Diff line number Diff line change
Expand Up @@ -683,12 +683,12 @@ namespace potrf {
* data movement between devices. This provides hints for load-balancing up front
* and avoids movement of the TRSM result to GEMM tasks.
*/
auto devmap1 = [&](const key1& key) { return (key[0] / A.P()) % ttg::device::num_devices(); }
auto devmap1 = [&](const Key1& key) { return (key[0] / A.P()) % ttg::device::num_devices(); };

auto devmap2a = [&](const key2& key) { return (key[0] / A.P()) % ttg::device::num_devices(); }
auto devmap2b = [&](const key2& key) { return (key[1] / A.P()) % ttg::device::num_devices(); }
auto devmap2a = [&](const Key2& key) { return (key[0] / A.P()) % ttg::device::num_devices(); };
auto devmap2b = [&](const Key2& key) { return (key[1] / A.P()) % ttg::device::num_devices(); };

auto devmap3 = [&](const key3& key) { return (key[0] / A.P()) % ttg::device::num_devices(); }
auto devmap3 = [&](const Key3& key) { return (key[0] / A.P()) % ttg::device::num_devices(); };

ttg::Edge<Key1, MatrixTile<T>> syrk_potrf("syrk_potrf"), disp_potrf("disp_potrf");

Expand All @@ -705,28 +705,28 @@ namespace potrf {
tt_potrf->set_keymap(keymap1);
tt_potrf->set_defer_writer(defer_write);
#ifdef ENABLE_DEVICE_KERNEL
tt_potrf->set_devmap(devmap1);
tt_potrf->set_devicemap(devmap1);
#endif // 0

auto tt_trsm = make_trsm(A, disp_trsm, potrf_trsm, gemm_trsm, trsm_syrk, trsm_gemm_row, trsm_gemm_col, output);
tt_trsm->set_keymap(keymap2a);
tt_trsm->set_defer_writer(defer_write);
#ifdef ENABLE_DEVICE_KERNEL
tt_trsm->set_devmap(devmap2a);
tt_trsm->set_devicemap(devmap2a);
#endif // 0

auto tt_syrk = make_syrk(A, disp_syrk, trsm_syrk, syrk_syrk, syrk_potrf, syrk_syrk);
tt_syrk->set_keymap(keymap2b);
tt_syrk->set_defer_writer(defer_write);
#ifdef ENABLE_DEVICE_KERNEL
tt_syrk->set_devmap(devmap2b);
tt_syrk->set_devicemap(devmap2b);
#endif // 0

auto tt_gemm = make_gemm(A, disp_gemm, trsm_gemm_row, trsm_gemm_col, gemm_gemm, gemm_trsm, gemm_gemm);
tt_gemm->set_keymap(keymap3);
tt_gemm->set_defer_writer(defer_write);
#ifdef ENABLE_DEVICE_KERNEL
tt_gemm->set_devmap(devmap3);
tt_gemm->set_devicemap(devmap3);
#endif // 0

/* Priorities taken from DPLASMA */
Expand Down
6 changes: 3 additions & 3 deletions ttg/ttg/parsec/ttg.h
Original file line number Diff line number Diff line change
Expand Up @@ -1537,13 +1537,13 @@ namespace ttg_parsec {
if (tt->devicemap) {
int parsec_dev;
if constexpr (std::is_void_v<keyT>) {
parsec_dev = ttg::device::ttg_device_to_parsec_device(tt->devicemap());
parsec_dev = detail::ttg_device_to_parsec_device(tt->devicemap());
} else {
parsec_dev = ttg::device::ttg_device_to_parsec_device(tt->devicemap(tt->key));
parsec_dev = detail::ttg_device_to_parsec_device(tt->devicemap(task->key));
}
for (int i = 0; i < MAX_PARAM_COUNT; ++i) {
/* only set on mutable data since we have exclusive access */
if (tc.in[i].flow_flags & PARSEC_FLOW_ACCESS_WRITE) {
if (tc.in[i]->flow_flags & PARSEC_FLOW_ACCESS_WRITE) {
parsec_data_t *data = parsec_task->data[i].data_in->original;
/* only set the preferred device if the host has the latest copy
* as otherwise we may end up with the wrong data if there is a newer
Expand Down
4 changes: 2 additions & 2 deletions ttg/ttg/util/meta.h
Original file line number Diff line number Diff line change
Expand Up @@ -851,11 +851,11 @@ namespace ttg {
template <typename Key, typename Return, typename Enabler = void>
struct keymap;
template <typename Key, typename Return>
struct keymap<Key, std::enable_if_t<!is_void_v<Key>>> {
struct keymap<Key, Return, std::enable_if_t<!is_void_v<Key>>> {
using type = std::function<Return(const Key &)>;
};
template <typename Key, typename Return>
struct keymap<Key, std::enable_if_t<is_void_v<Key>>> {
struct keymap<Key, Return, std::enable_if_t<is_void_v<Key>>> {
using type = std::function<Return()>;
};
template <typename Key, typename Return = int>
Expand Down

0 comments on commit 8164c44

Please sign in to comment.