Skip to content

Commit

Permalink
PaRSEC: Add have_level_zero_op()
Browse files Browse the repository at this point in the history
Signed-off-by: Joseph Schuchart <[email protected]>
  • Loading branch information
devreal committed Nov 16, 2023
1 parent afeda46 commit d1fb80f
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
1 change: 1 addition & 0 deletions ttg/ttg/make_tt.h
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ class CallableWrapTTArgs
public:
static constexpr bool have_cuda_op = (space == ttg::ExecutionSpace::CUDA);
static constexpr bool have_hip_op = (space == ttg::ExecutionSpace::HIP);
static constexpr bool have_level_zero_op = (space == ttg::ExecutionSpace::L0);

protected:

Expand Down
32 changes: 31 additions & 1 deletion ttg/ttg/parsec/ttg.h
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,16 @@ namespace ttg_parsec {
}
}

template<typename TT>
inline parsec_hook_return_t hook_level_zero(struct parsec_execution_stream_s *es, parsec_task_t *parsec_task) {
if constexpr(TT::derived_has_level_zero_op()) {
parsec_ttg_task_t<TT> *me = (parsec_ttg_task_t<TT> *)parsec_task;
return me->template invoke_op<ttg::ExecutionSpace::L0>();
} else {
throw std::runtime_error("PaRSEC HIP hook invoked on a TT that does not support HIP operations!");
}
}

template <typename KeyT, typename ActivationCallbackT>
class rma_delayed_activate {
std::vector<KeyT> _keylist;
Expand Down Expand Up @@ -1129,9 +1139,18 @@ namespace ttg_parsec {
}
}

/// @return true if derivedT::have_hip_op exists and is defined to true
static constexpr bool derived_has_level_zero_op() {
if constexpr (ttg::meta::is_detected_v<have_level_zero_op_non_type_t, derivedT>) {
return derivedT::have_level_zero_op;
} else {
return false;
}
}

/// @return true if the TT supports device execution
static constexpr bool derived_has_device_op() {
return (derived_has_cuda_op() || derived_has_hip_op());
return (derived_has_cuda_op() || derived_has_hip_op() || derived_has_level_zero_op());
}

using ttT = TT;
Expand Down Expand Up @@ -3239,6 +3258,8 @@ ttg::abort(); // should not happen
device_supported = !world.impl().mpi_support(ttg::ExecutionSpace::CUDA);
} else if constexpr (derived_has_hip_op()) {
device_supported = !world.impl().mpi_support(ttg::ExecutionSpace::HIP);
} else if constexpr (derived_has_level_zero_op()) {
device_supported = !world.impl().mpi_support(ttg::ExecutionSpace::L0);
}
/* if MPI supports the device we don't care whether we have remote peers
* because we can send from the device directly */
Expand Down Expand Up @@ -3642,6 +3663,15 @@ ttg::abort(); // should not happen
((__parsec_chore_t *)self.incarnations)[0].evaluate = NULL;
((__parsec_chore_t *)self.incarnations)[0].hook = &detail::hook_hip<TT>;

((__parsec_chore_t *)self.incarnations)[1].type = PARSEC_DEV_NONE;
((__parsec_chore_t *)self.incarnations)[1].evaluate = NULL;
((__parsec_chore_t *)self.incarnations)[1].hook = NULL;
} else if (derived_has_level_zero_op()) {
self.incarnations = (__parsec_chore_t *)malloc(3 * sizeof(__parsec_chore_t));
((__parsec_chore_t *)self.incarnations)[0].type = PARSEC_DEV_LEVEL_ZERO;
((__parsec_chore_t *)self.incarnations)[0].evaluate = NULL;
((__parsec_chore_t *)self.incarnations)[0].hook = &detail::hook_level_zero<TT>;

((__parsec_chore_t *)self.incarnations)[1].type = PARSEC_DEV_NONE;
((__parsec_chore_t *)self.incarnations)[1].evaluate = NULL;
((__parsec_chore_t *)self.incarnations)[1].hook = NULL;
Expand Down

0 comments on commit d1fb80f

Please sign in to comment.