Skip to content

Commit

Permalink
Pass the Execution Space as TT template parameter
Browse files Browse the repository at this point in the history
Defaults to Host execution so existing code is not affected.
Properly set by make_tt.

We cannot query TT::derivedT for flags because at the time TT
is instantiated because derivedT is incomplete at that point.
For now pass the Space as template parameter. We need to find
a different way if we want to have multiple implementations
of a task.

Signed-off-by: Joseph Schuchart <[email protected]>
  • Loading branch information
devreal committed Nov 19, 2024
1 parent b180cac commit fda2d8a
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 59 deletions.
4 changes: 3 additions & 1 deletion ttg/ttg/madness/fwd.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

namespace ttg_madness {

template <typename keyT, typename output_terminalsT, typename derivedT, typename input_valueTs = ttg::typelist<>>
template <typename keyT, typename output_terminalsT, typename derivedT,
typename input_valueTs = ttg::typelist<>,
ttg::ExecutionSpace Space = ttg::ExecutionSpace::Host>
class TT;

/// \internal the OG name
Expand Down
5 changes: 3 additions & 2 deletions ttg/ttg/madness/ttg.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,9 @@ namespace ttg_madness {
/// values
/// flowing into this TT; a const type indicates nonmutating (read-only) use, nonconst type
/// indicates mutating use (e.g. the corresponding input can be used as scratch, moved-from, etc.)
template <typename keyT, typename output_terminalsT, typename derivedT, typename input_valueTs>
class TT : public ttg::TTBase, public ::madness::WorldObject<TT<keyT, output_terminalsT, derivedT, input_valueTs>> {
template <typename keyT, typename output_terminalsT, typename derivedT, typename input_valueTs, ttg::ExecutionSpace Space>
class TT : public ttg::TTBase, public ::madness::WorldObject<TT<keyT, output_terminalsT, derivedT, input_valueTs, Space>> {
static_assert(Space == ttg::ExecutionSpace::Host, "MADNESS backend only supports Host Execution Space");
static_assert(ttg::meta::is_typelist_v<input_valueTs>,
"The fourth template for ttg::TT must be a ttg::typelist containing the input types");
using input_tuple_type = ttg::meta::typelist_to_tuple_t<input_valueTs>;
Expand Down
7 changes: 1 addition & 6 deletions ttg/ttg/make_tt.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class CallableWrapTT
: public TT<
keyT, output_terminalsT,
CallableWrapTT<funcT, returnT, funcT_receives_input_tuple, funcT_receives_outterm_tuple, space, keyT, output_terminalsT, input_valuesT...>,
ttg::typelist<input_valuesT...>> {
ttg::typelist<input_valuesT...>, space> {
using baseT = typename CallableWrapTT::ttT;

using input_values_tuple_type = typename baseT::input_values_tuple_type;
Expand All @@ -44,11 +44,6 @@ class CallableWrapTT
void;
#endif // TTG_HAVE_COROUTINE

public:
static constexpr bool have_cuda_op = (space == ttg::ExecutionSpace::CUDA);
static constexpr bool have_hip_op = (space == ttg::ExecutionSpace::HIP);
static constexpr bool have_level_zero_op = (space == ttg::ExecutionSpace::L0);

protected:

template<typename ReturnT>
Expand Down
4 changes: 3 additions & 1 deletion ttg/ttg/parsec/fwd.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ extern "C" struct parsec_context_s;

namespace ttg_parsec {

template <typename keyT, typename output_terminalsT, typename derivedT, typename input_valueTs = ttg::typelist<>>
template <typename keyT, typename output_terminalsT, typename derivedT,
typename input_valueTs = ttg::typelist<>,
ttg::ExecutionSpace Space = ttg::ExecutionSpace::Host>
class TT;

/// \internal the OG name
Expand Down
12 changes: 6 additions & 6 deletions ttg/ttg/parsec/task.h
Original file line number Diff line number Diff line change
Expand Up @@ -252,9 +252,9 @@ namespace ttg_parsec {
template<ttg::ExecutionSpace Space>
parsec_hook_return_t invoke_op() {
if constexpr (Space == ttg::ExecutionSpace::Host) {
return TT::template static_op<Space>(&this->parsec_task);
return TT::static_op(&this->parsec_task);
} else {
return TT::template device_static_op<Space>(&this->parsec_task);
return TT::device_static_op(&this->parsec_task);
}
}

Expand All @@ -263,7 +263,7 @@ namespace ttg_parsec {
if constexpr (Space == ttg::ExecutionSpace::Host) {
return PARSEC_HOOK_RETURN_DONE;
} else {
return TT::template device_static_evaluate<Space>(&this->parsec_task);
return TT::device_static_evaluate(&this->parsec_task);
}
}

Expand Down Expand Up @@ -310,9 +310,9 @@ namespace ttg_parsec {
template<ttg::ExecutionSpace Space>
parsec_hook_return_t invoke_op() {
if constexpr (Space == ttg::ExecutionSpace::Host) {
return TT::template static_op<Space>(&this->parsec_task);
return TT::static_op(&this->parsec_task);
} else {
return TT::template device_static_op<Space>(&this->parsec_task);
return TT::device_static_op(&this->parsec_task);
}
}

Expand All @@ -321,7 +321,7 @@ namespace ttg_parsec {
if constexpr (Space == ttg::ExecutionSpace::Host) {
return PARSEC_HOOK_RETURN_DONE;
} else {
return TT::template device_static_evaluate<Space>(&this->parsec_task);
return TT::device_static_evaluate(&this->parsec_task);
}
}

Expand Down
70 changes: 27 additions & 43 deletions ttg/ttg/parsec/ttg.h
Original file line number Diff line number Diff line change
Expand Up @@ -514,8 +514,9 @@ namespace ttg_parsec {
#endif // TTG_USE_USER_TERMDET
}

template <typename keyT, typename output_terminalsT, typename derivedT, typename input_valueTs = ttg::typelist<>>
void register_tt_profiling(const TT<keyT, output_terminalsT, derivedT, input_valueTs> *t) {
template <typename keyT, typename output_terminalsT, typename derivedT,
typename input_valueTs = ttg::typelist<>, ttg::ExecutionSpace Space>
void register_tt_profiling(const TT<keyT, output_terminalsT, derivedT, input_valueTs, Space> *t) {
#if defined(PARSEC_PROF_TRACE)
std::stringstream ss;
build_composite_name_rec(t->ttg_ptr(), ss);
Expand Down Expand Up @@ -1180,7 +1181,7 @@ namespace ttg_parsec {

} // namespace detail

template <typename keyT, typename output_terminalsT, typename derivedT, typename input_valueTs>
template <typename keyT, typename output_terminalsT, typename derivedT, typename input_valueTs, ttg::ExecutionSpace Space>
class TT : public ttg::TTBase, detail::ParsecTTBase {
private:
/// preconditions
Expand Down Expand Up @@ -1217,29 +1218,17 @@ namespace ttg_parsec {
public:
/// @return true if derivedT::have_cuda_op exists and is defined to true
static constexpr bool derived_has_cuda_op() {
if constexpr (ttg::meta::is_detected_v<have_cuda_op_non_type_t, derivedT>) {
return derivedT::have_cuda_op;
} else {
return false;
}
return Space == ttg::ExecutionSpace::CUDA;
}

/// @return true if derivedT::have_hip_op exists and is defined to true
static constexpr bool derived_has_hip_op() {
if constexpr (ttg::meta::is_detected_v<have_hip_op_non_type_t, derivedT>) {
return derivedT::have_hip_op;
} else {
return false;
}
return Space == ttg::ExecutionSpace::HIP;
}

/// @return true if derivedT::have_hip_op exists and is defined to true
static constexpr bool derived_has_level_zero_op() {
if constexpr (ttg::meta::is_detected_v<have_level_zero_op_non_type_t, derivedT>) {
return derivedT::have_level_zero_op;
} else {
return false;
}
return Space == ttg::ExecutionSpace::L0;
}

/// @return true if the TT supports device execution
Expand Down Expand Up @@ -1354,18 +1343,17 @@ namespace ttg_parsec {
/// dispatches a call to derivedT::op
/// @return void if called a synchronous function, or ttg::coroutine_handle<> if called a coroutine (if non-null,
/// points to the suspended coroutine)
template <ttg::ExecutionSpace Space, typename... Args>
template <typename... Args>
auto op(Args &&...args) {
derivedT *derived = static_cast<derivedT *>(this);
//if constexpr (Space == ttg::ExecutionSpace::Host) {
using return_type = decltype(derived->op(std::forward<Args>(args)...));
if constexpr (std::is_same_v<return_type,void>) {
derived->op(std::forward<Args>(args)...);
return;
}
else {
return derived->op(std::forward<Args>(args)...);
}
using return_type = decltype(derived->op(std::forward<Args>(args)...));
if constexpr (std::is_same_v<return_type,void>) {
derived->op(std::forward<Args>(args)...);
return;
}
else {
return derived->op(std::forward<Args>(args)...);
}
}

template <std::size_t i, typename terminalT, typename Key>
Expand Down Expand Up @@ -1418,7 +1406,6 @@ namespace ttg_parsec {
/**
* Submit callback called by PaRSEC once all input transfers have completed.
*/
template <ttg::ExecutionSpace Space>
static int device_static_submit(parsec_device_gpu_module_t *gpu_device,
parsec_gpu_task_t *gpu_task,
parsec_gpu_exec_stream_t *gpu_stream) {
Expand Down Expand Up @@ -1464,7 +1451,7 @@ namespace ttg_parsec {
#endif // defined(PARSEC_HAVE_DEV_CUDA_SUPPORT) && defined(TTG_HAVE_CUDA)

/* Here we call back into the coroutine again after the transfers have completed */
static_op<Space>(&task->parsec_task);
static_op(&task->parsec_task);

ttg::device::detail::reset_current();

Expand Down Expand Up @@ -1494,7 +1481,6 @@ namespace ttg_parsec {
return rc;
}

template <ttg::ExecutionSpace Space>
static parsec_hook_return_t device_static_evaluate(parsec_task_t* parsec_task) {

task_t *task = (task_t*)parsec_task;
Expand All @@ -1509,7 +1495,7 @@ namespace ttg_parsec {
gpu_task->task_type = 0; // user task
gpu_task->last_data_check_epoch = 0; // used internally
gpu_task->pushout = 0;
gpu_task->submit = &TT::device_static_submit<Space>;
gpu_task->submit = &TT::device_static_submit;

// one way to force the task device
// currently this will probably break all of PaRSEC if this hint
Expand All @@ -1527,7 +1513,7 @@ namespace ttg_parsec {
task->dev_ptr->task_class = *task->parsec_task.task_class;

// first invocation of the coroutine to get the coroutine handle
static_op<Space>(parsec_task);
static_op(parsec_task);

/* when we come back here, the flows in gpu_task are set (see register_device_memory) */

Expand Down Expand Up @@ -1577,7 +1563,6 @@ namespace ttg_parsec {

}

template <ttg::ExecutionSpace Space>
static parsec_hook_return_t device_static_op(parsec_task_t* parsec_task) {
static_assert(derived_has_device_op());

Expand Down Expand Up @@ -1649,7 +1634,6 @@ namespace ttg_parsec {
}
#endif // TTG_HAVE_DEVICE

template <ttg::ExecutionSpace Space>
static parsec_hook_return_t static_op(parsec_task_t *parsec_task) {

task_t *task = (task_t*)parsec_task;
Expand All @@ -1675,14 +1659,14 @@ namespace ttg_parsec {

if constexpr (!ttg::meta::is_void_v<keyT> && !ttg::meta::is_empty_tuple_v<input_values_tuple_type>) {
auto input = make_tuple_of_ref_from_array(task, std::make_index_sequence<numinvals>{});
TTG_PROCESS_TT_OP_RETURN(suspended_task_address, task->coroutine_id, baseobj->template op<Space>(task->key, std::move(input), obj->output_terminals));
TTG_PROCESS_TT_OP_RETURN(suspended_task_address, task->coroutine_id, baseobj->op(task->key, std::move(input), obj->output_terminals));
} else if constexpr (!ttg::meta::is_void_v<keyT> && ttg::meta::is_empty_tuple_v<input_values_tuple_type>) {
TTG_PROCESS_TT_OP_RETURN(suspended_task_address, task->coroutine_id, baseobj->template op<Space>(task->key, obj->output_terminals));
TTG_PROCESS_TT_OP_RETURN(suspended_task_address, task->coroutine_id, baseobj->op(task->key, obj->output_terminals));
} else if constexpr (ttg::meta::is_void_v<keyT> && !ttg::meta::is_empty_tuple_v<input_values_tuple_type>) {
auto input = make_tuple_of_ref_from_array(task, std::make_index_sequence<numinvals>{});
TTG_PROCESS_TT_OP_RETURN(suspended_task_address, task->coroutine_id, baseobj->template op<Space>(std::move(input), obj->output_terminals));
TTG_PROCESS_TT_OP_RETURN(suspended_task_address, task->coroutine_id, baseobj->op(std::move(input), obj->output_terminals));
} else if constexpr (ttg::meta::is_void_v<keyT> && ttg::meta::is_empty_tuple_v<input_values_tuple_type>) {
TTG_PROCESS_TT_OP_RETURN(suspended_task_address, task->coroutine_id, baseobj->template op<Space>(obj->output_terminals));
TTG_PROCESS_TT_OP_RETURN(suspended_task_address, task->coroutine_id, baseobj->op(obj->output_terminals));
} else {
ttg::abort();
}
Expand Down Expand Up @@ -1758,7 +1742,6 @@ namespace ttg_parsec {
return PARSEC_HOOK_RETURN_DONE;
}

template <ttg::ExecutionSpace Space>
static parsec_hook_return_t static_op_noarg(parsec_task_t *parsec_task) {
task_t *task = static_cast<task_t*>(parsec_task);

Expand All @@ -1774,9 +1757,9 @@ namespace ttg_parsec {
assert(detail::parsec_ttg_caller == NULL);
detail::parsec_ttg_caller = task;
if constexpr (!ttg::meta::is_void_v<keyT>) {
TTG_PROCESS_TT_OP_RETURN(suspended_task_address, task->coroutine_id, baseobj->template op<Space>(task->key, obj->output_terminals));
TTG_PROCESS_TT_OP_RETURN(suspended_task_address, task->coroutine_id, baseobj->op(task->key, obj->output_terminals));
} else if constexpr (ttg::meta::is_void_v<keyT>) {
TTG_PROCESS_TT_OP_RETURN(suspended_task_address, task->coroutine_id, baseobj->template op<Space>(obj->output_terminals));
TTG_PROCESS_TT_OP_RETURN(suspended_task_address, task->coroutine_id, baseobj->op(obj->output_terminals));
} else // unreachable
ttg:: abort();
detail::parsec_ttg_caller = NULL;
Expand Down Expand Up @@ -4330,7 +4313,7 @@ namespace ttg_parsec {
void make_executable() override {
world.impl().register_tt_profiling(this);
register_static_op_function();
ttg::TTBase::make_executable();
::ttg::TTBase::make_executable();
}

/// keymap accessor
Expand Down Expand Up @@ -4376,6 +4359,7 @@ namespace ttg_parsec {
return ttg::device::Device(dm(key), ttg::ExecutionSpace::L0);
} else {
throw std::runtime_error("Unknown device type!");
return ttg::device::Device{};
}
};
}
Expand Down

0 comments on commit fda2d8a

Please sign in to comment.