Skip to content

Commit

Permalink
Add device hint to TT and buffer
Browse files Browse the repository at this point in the history
For POTRF, we want to provide a hint that tasks on the same column
should be executed on the same device, to reduce data movement
and provide a hint on load balancing up front.

Signed-off-by: Joseph Schuchart <[email protected]>
  • Loading branch information
devreal committed May 15, 2024
1 parent 2a66007 commit 4f86f7b
Show file tree
Hide file tree
Showing 10 changed files with 135 additions and 8 deletions.
10 changes: 9 additions & 1 deletion examples/potrf/pmw.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,14 @@ class PaRSECMatrixWrapper {
(pm->uplo == PARSEC_MATRIX_UPPER && col >= row);
}

/// Returns `pm->grid.rows`, the number of rows of the grid stored in the
/// wrapped matrix descriptor (presumably the process grid — confirm).
int P() const {
  const int grid_rows = pm->grid.rows;
  return grid_rows;
}

/// Returns `pm->grid.cols`, the number of columns of the grid stored in the
/// wrapped matrix descriptor (presumably the process grid — confirm).
int Q() const {
  const int grid_cols = pm->grid.cols;
  return grid_cols;
}

/// Gives access to the wrapped matrix descriptor.
/// @return the `pm` pointer held by this wrapper
PaRSECMatrixT* parsec() { return pm; }
Expand Down Expand Up @@ -132,7 +140,7 @@ class PaRSECMatrixWrapper {
};

/// Matrix wrapper alias over the `parsec_matrix_sym_block_cyclic_t` descriptor type.
/// Only one declaration of the alias is kept: declaring it a second time with
/// a different descriptor type is a conflicting redeclaration.
template<typename ValueT>
using MatrixT = PaRSECMatrixWrapper<parsec_matrix_sym_block_cyclic_t, ValueT>;

static auto make_load_tt(MatrixT<double> &A, ttg::Edge<Key2, MatrixTile<double>> &toop, bool defer_write)
{
Expand Down
24 changes: 24 additions & 0 deletions examples/potrf/potrf.h
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,18 @@ namespace potrf {

auto keymap3 = [&](const Key3& key) { return A.rank_of(key[0], key[1]); };

/**
* Device map hints: we try to keep tiles on one row on the same device to minimize
* data movement between devices. This provides hints for load-balancing up front
* and avoids movement of the TRSM result to GEMM tasks.
*/
auto devmap1 = [&](const key1& key) { return (key[0] / A.P()) % ttg::device::num_devices(); }

auto devmap2a = [&](const key2& key) { return (key[0] / A.P()) % ttg::device::num_devices(); }
auto devmap2b = [&](const key2& key) { return (key[1] / A.P()) % ttg::device::num_devices(); }

auto devmap3 = [&](const key3& key) { return (key[0] / A.P()) % ttg::device::num_devices(); }

ttg::Edge<Key1, MatrixTile<T>> syrk_potrf("syrk_potrf"), disp_potrf("disp_potrf");

ttg::Edge<Key2, MatrixTile<T>> potrf_trsm("potrf_trsm"), trsm_syrk("trsm_syrk"), gemm_trsm("gemm_trsm"),
Expand All @@ -692,18 +704,30 @@ namespace potrf {
/* POTRF task template: process mapping, deferred writer and (device builds only) device hint */
auto tt_potrf = make_potrf(A, disp_potrf, syrk_potrf, potrf_trsm, output);
tt_potrf->set_keymap(keymap1);
tt_potrf->set_defer_writer(defer_write);
#ifdef ENABLE_DEVICE_KERNEL
/* the device-hint setter of the TT is set_devicemap() */
tt_potrf->set_devicemap(devmap1);
#endif // ENABLE_DEVICE_KERNEL

/* TRSM task template: process mapping, deferred writer and (device builds only) device hint */
auto tt_trsm = make_trsm(A, disp_trsm, potrf_trsm, gemm_trsm, trsm_syrk, trsm_gemm_row, trsm_gemm_col, output);
tt_trsm->set_keymap(keymap2a);
tt_trsm->set_defer_writer(defer_write);
#ifdef ENABLE_DEVICE_KERNEL
/* the device-hint setter of the TT is set_devicemap() */
tt_trsm->set_devicemap(devmap2a);
#endif // ENABLE_DEVICE_KERNEL

/* SYRK task template: process mapping, deferred writer and (device builds only) device hint */
auto tt_syrk = make_syrk(A, disp_syrk, trsm_syrk, syrk_syrk, syrk_potrf, syrk_syrk);
tt_syrk->set_keymap(keymap2b);
tt_syrk->set_defer_writer(defer_write);
#ifdef ENABLE_DEVICE_KERNEL
/* the device-hint setter of the TT is set_devicemap() */
tt_syrk->set_devicemap(devmap2b);
#endif // ENABLE_DEVICE_KERNEL

/* GEMM task template: process mapping, deferred writer and (device builds only) device hint */
auto tt_gemm = make_gemm(A, disp_gemm, trsm_gemm_row, trsm_gemm_col, gemm_gemm, gemm_trsm, gemm_gemm);
tt_gemm->set_keymap(keymap3);
tt_gemm->set_defer_writer(defer_write);
#ifdef ENABLE_DEVICE_KERNEL
/* the device-hint setter of the TT is set_devicemap() */
tt_gemm->set_devicemap(devmap3);
#endif // ENABLE_DEVICE_KERNEL

/* Priorities taken from DPLASMA */
auto nt = A.cols();
Expand Down
7 changes: 7 additions & 0 deletions ttg/ttg/device/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "ttg/config.h"
#include "ttg/execution.h"
#include "ttg/impl_selector.h"



Expand Down Expand Up @@ -180,3 +181,9 @@ namespace ttg::device {
}
} // namespace ttg
#endif // defined(TTG_HAVE_HIP)

namespace ttg::device {

  /// Number of devices as reported by the selected backend.
  /// Forwards to `TTG_IMPL_NS::num_devices()`.
  inline int num_devices() { return TTG_IMPL_NS::num_devices(); }

}  // namespace ttg::device
9 changes: 9 additions & 0 deletions ttg/ttg/madness/device.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#ifndef TTG_MADNESS_DEVICE_H
#define TTG_MADNESS_DEVICE_H

namespace ttg_madness {

  /// Number of devices available to the MADNESS backend.
  /// There is no device support in MADNESS, so the answer is always zero.
  inline int num_devices() {
    return 0;
  }

}  // namespace ttg_madness

#endif // TTG_MADNESS_DEVICE_H
1 change: 1 addition & 0 deletions ttg/ttg/madness/ttg.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "ttg/base/keymap.h"
#include "ttg/base/tt.h"
#include "ttg/func.h"
#include "ttg/madness/device.h"
#include "ttg/runtimes.h"
#include "ttg/tt.h"
#include "ttg/util/bug.h"
Expand Down
8 changes: 8 additions & 0 deletions ttg/ttg/parsec/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,14 @@ struct Buffer : public detail::ttg_parsec_data_wrapper_t
// << " parsec_data " << m_data.get() << std::endl;
}

/// Advises PaRSEC that this buffer's data should preferably reside on @p dev.
/// Nothing happens unless @p dev is a device and the data's `owner_device` is 0.
/// @param dev the device to prefer for this buffer
void prefer_device(ttg::device::Device dev) {
  if (!dev.is_device()) {
    return;  // no advice to give for a non-device target
  }
  /* only set device if the host has the latest copy as otherwise we might end up with a stale copy */
  if (this->parsec_data()->owner_device != 0) {
    return;
  }
  parsec_advise_data_on_device(this->parsec_data(), detail::ttg_device_to_parsec_device(dev),
                               PARSEC_DEV_DATA_ADVICE_PREFERRED_DEVICE);
}

/* serialization support */

#ifdef TTG_SERIALIZATION_SUPPORTS_BOOST
Expand Down
7 changes: 7 additions & 0 deletions ttg/ttg/parsec/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define TTG_PARSEC_DEVICE_H

#include "ttg/device/device.h"
#include <parsec/mca/device/device.h>

namespace ttg_parsec {

Expand Down Expand Up @@ -35,6 +36,12 @@ namespace ttg_parsec {
}
} // namespace detail


/// Number of devices seen by the PaRSEC backend: `parsec_nb_devices` minus
/// `detail::first_device_id` (presumably skipping PaRSEC's non-accelerator
/// devices — confirm against the definition of first_device_id).
inline int num_devices() { return parsec_nb_devices - detail::first_device_id; }

} // namespace ttg_parsec

#endif // TTG_PARSEC_DEVICE_H
2 changes: 2 additions & 0 deletions ttg/ttg/parsec/fwd.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ namespace ttg_parsec {
template<typename... Buffer>
inline void mark_device_out(std::tuple<Buffer&...> &b);

/// Forward declaration of the PaRSEC backend's device count query
/// (defined in ttg/parsec/device.h).
inline int num_devices();

#if 0
template<typename... Args>
inline std::pair<bool, std::tuple<ptr<std::decay_t<Args>>...>> get_ptr(Args&&... args);
Expand Down
61 changes: 61 additions & 0 deletions ttg/ttg/parsec/ttg.h
Original file line number Diff line number Diff line change
Expand Up @@ -1296,6 +1296,7 @@ namespace ttg_parsec {
ttg::World world;
ttg::meta::detail::keymap_t<keyT> keymap;
ttg::meta::detail::keymap_t<keyT> priomap;
// optional device hint: maps a task key to the device the task should preferably run on;
// empty unless set through set_devicemap()
ttg::meta::detail::keymap_t<keyT, ttg::device::Device> devicemap;
// For now use same type for unary/streaming input terminals, and stream reducers assigned at runtime
ttg::meta::detail::input_reducers_t<actual_input_tuple_type>
input_reducers; //!< Reducers for the input terminals (empty = expect single value)
Expand Down Expand Up @@ -1502,6 +1503,12 @@ namespace ttg_parsec {
gpu_task->pushout = 0;
gpu_task->submit = &TT::device_static_submit<Space>;

// one way to force the task device
// currently this will probably break all of PaRSEC if this hint
// does not match where the data is located, not really useful for us
// instead we set a hint on the data if there is no hint set yet
//parsec_task->selected_device = ...;

/* set the gpu_task so it's available in register_device_memory */
task->dev_ptr->gpu_task = gpu_task;

Expand All @@ -1525,6 +1532,29 @@ namespace ttg_parsec {
}
tc.nb_flows = MAX_PARAM_COUNT;

/* set the device hint on the data */
TT *tt = task->tt;
if (tt->devicemap) {
  /* Translate the TTG device returned by the device map into a PaRSEC device index.
   * The conversion helper lives in ttg_parsec::detail and the key belongs to the
   * task (not to the task template), so it is read from the task. */
  int parsec_dev;
  if constexpr (std::is_void_v<keyT>) {
    parsec_dev = detail::ttg_device_to_parsec_device(tt->devicemap());
  } else {
    parsec_dev = detail::ttg_device_to_parsec_device(tt->devicemap(task->key));
  }
  for (int i = 0; i < MAX_PARAM_COUNT; ++i) {
    /* only set on mutable data since we have exclusive access */
    if (tc.in[i].flow_flags & PARSEC_FLOW_ACCESS_WRITE) {
      parsec_data_t *data = parsec_task->data[i].data_in->original;
      /* only set the preferred device if the host has the latest copy
       * as otherwise we may end up with the wrong data if there is a newer
       * version on a different device. Also, keep fingers crossed. */
      if (data->owner_device == 0) {
        parsec_advise_data_on_device(data, parsec_dev, PARSEC_DEV_DATA_ADVICE_PREFERRED_DEVICE);
      }
    }
  }
}

/* set the new task class that contains the flows */
task->parsec_task.task_class = &task->dev_ptr->task_class;

Expand Down Expand Up @@ -4195,6 +4225,37 @@ ttg::abort(); // should not happen
priomap = std::forward<Priomap>(pm);
}

/// device map setter
/// The device map provides a hint on which device a task should execute.
/// TTG may not be able to honor the request and the corresponding task
/// may execute on a different device.
/// @arg pm a function that provides a hint on which device the task should execute.
template<typename Devicemap>
void set_devicemap(Devicemap&& dm) {
static_assert(derived_has_device_op(), "Device map only allowed on device-enabled TT!");
if constexpr (std::is_same_v<ttg::device::Device, decltype(dm(std::declval<keyT>()))>) {
// dm returns a Device
devicemap = std::forward<Devicemap>(dm);
} else {
// convert dm return into a Device
devicemap = [=](const keyT& key) {
if constexpr (derived_has_cuda_op()) {
return ttg::device::Device(dm(key), ttg::ExecutionSpace::CUDA);
} else if constexpr (derived_has_hip_op()) {
return ttg::device::Device(dm(key), ttg::ExecutionSpace::HIP);
} else if constexpr (derived_has_level_zero_op()) {
return ttg::device::Device(dm(key), ttg::ExecutionSpace::L0);
} else {
throw std::runtime_error("Unknown device type!");
}
};
}
}

/// device map accessor
/// @return a copy of the device map currently stored in this TT
auto get_devicemap() {
  return devicemap;
}

// Register the static_op function to associate it to instance_id
void register_static_op_function(void) {
int rank;
Expand Down
14 changes: 7 additions & 7 deletions ttg/ttg/util/meta.h
Original file line number Diff line number Diff line change
Expand Up @@ -848,18 +848,18 @@ namespace ttg {
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// keymap_t<key,value> = std::function<value(const key&)>, protected against void key;
// value defaults to int so that keymap_t<key> keeps meaning std::function<int(const key&)>
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Key, typename Return, typename Enabler = void>
struct keymap;
// non-void key: the map is called with the key.
// Return must be named in the specialization's argument list, otherwise it is not deducible.
template <typename Key, typename Return>
struct keymap<Key, Return, std::enable_if_t<!is_void_v<Key>>> {
  using type = std::function<Return(const Key &)>;
};
// void key: the map takes no arguments
template <typename Key, typename Return>
struct keymap<Key, Return, std::enable_if_t<is_void_v<Key>>> {
  using type = std::function<Return()>;
};
template <typename Key, typename Return = int>
using keymap_t = typename keymap<Key, Return>::type;

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// input_reducers_t<valueTs...> = std::tuple<
Expand Down

0 comments on commit 4f86f7b

Please sign in to comment.