Skip to content

Commit

Permalink
MRA: make sure everything is sent through ttg::device::broadcast
Browse files Browse the repository at this point in the history
Also adds ttg::device::broadcastk

Signed-off-by: Joseph Schuchart <[email protected]>
  • Loading branch information
devreal committed Sep 5, 2024
1 parent 4928703 commit 9af4ae8
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 8 deletions.
11 changes: 7 additions & 4 deletions examples/madness/mra-device/mrattg-device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,14 @@ auto make_project(
using node_type = typename mra::FunctionReconstructedNode<T, NDIM>;
node_type result;
tensor_type& coeffs = result.coeffs;
auto outputs = ttg::device::forward();

if (key.level() < initial_level(f)) {
std::vector<mra::Key<NDIM>> bcast_keys;
/* TODO: children() returns an iteratable object but broadcast() expects a contiguous memory range.
We need to fix broadcast to support any ranges */
for (auto child : children(key)) bcast_keys.push_back(child);
ttg::broadcastk<0>(bcast_keys);
outputs.push_back(ttg::device::broadcastk<0>(std::move(bcast_keys)));
coeffs.current_view() = T(1e7); // set to obviously bad value to detect incorrect use
result.is_leaf = false;
}
Expand Down Expand Up @@ -110,13 +111,15 @@ auto make_project(
if (!result.is_leaf) {
std::vector<mra::Key<NDIM>> bcast_keys;
for (auto child : children(key)) bcast_keys.push_back(child);
ttg::broadcastk<0>(bcast_keys);
outputs.push_back(ttg::device::broadcastk<0>(std::move(bcast_keys)));
}
}
ttg::send<1>(key, std::move(result)); // always produce a result
outputs.push_back(ttg::device::send<1>(key, std::move(result))); // always produce a result
co_await std::move(outputs);
};

return ttg::make_tt<Space>(std::move(fn), ttg::edges(control), ttg::edges(result), "project");
ttg::Edge<mra::Key<NDIM>, void> refine("refine");
return ttg::make_tt<Space>(std::move(fn), ttg::edges(fuse(control, refine)), ttg::edges(refine,result), "project");
}

template<mra::Dimension NDIM, typename Value, std::size_t I, std::size_t... Is>
Expand Down
97 changes: 93 additions & 4 deletions ttg/ttg/device/task.h
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ namespace ttg::device {
/* overload for iterable types that extracts the type of the first element */
template<typename T>
struct broadcast_keylist_trait<T, std::enable_if_t<ttg::meta::is_iterable_v<T>>> {
using key_type = decltype(*std::begin(std::get<0>(std::declval<T>())));
using key_type = decltype(*std::begin(std::declval<T>()));
};

template <size_t KeyId, size_t I, size_t... Is, typename... RangesT, typename valueT,
Expand Down Expand Up @@ -364,6 +364,74 @@ namespace ttg::device {
ttg::device::detail::broadcast<0, I, Is...>(std::tie(kl), std::forward<std::decay_t<decltype(value)>>(value));
}
}



/**
* broadcastk
*/

template <size_t KeyId, size_t I, size_t... Is, typename... RangesT,
typename... out_keysT, typename... out_valuesT>
inline void broadcastk(const std::tuple<RangesT...> &keylists,
std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
std::get<I>(t).broadcast(std::get<KeyId>(keylists));
if constexpr (sizeof...(Is) > 0) {
detail::broadcastk<KeyId+1, Is...>(keylists, t);
}
}

template <size_t KeyId, size_t I, size_t... Is, typename... RangesT,
typename... out_keysT, typename... out_valuesT>
inline void broadcastk(const std::tuple<RangesT...> &keylists) {
using key_t = typename broadcast_keylist_trait<
std::tuple_element_t<KeyId, std::tuple<std::remove_reference_t<RangesT>...>>
>::key_type;
auto *terminal_ptr = ttg::detail::get_out_terminal<key_t, void>(I, "ttg::device::broadcastk(keylists)");
terminal_ptr->broadcast(std::get<KeyId>(keylists));
if constexpr (sizeof...(Is) > 0) {
ttg::device::detail::broadcastk<KeyId+1, Is...>(keylists);
}
}

/* overload with explicit terminals */
template <size_t I, size_t... Is, typename RangesT,
typename... out_keysT, typename... out_valuesT,
ttg::Runtime Runtime = ttg::ttg_runtime>
inline send_coro_state
broadcastk_coro(RangesT &&keylists,
std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
RangesT kl = std::forward<RangesT>(keylists); // capture the keylist(s)
if constexpr (ttg::meta::is_tuple_v<RangesT>) {
// treat as tuple
co_await ttg::Void{}; // we'll come back once the task is done
ttg::device::detail::broadcastk<0, I, Is...>(kl, t);
} else if constexpr (!ttg::meta::is_tuple_v<RangesT>) {
// create a tie to the captured keylist
co_await ttg::Void{}; // we'll come back once the task is done
ttg::device::detail::broadcastk<0, I, Is...>(std::tie(kl), t);
}
}

/* overload with implicit terminals */
template <size_t I, size_t... Is, typename RangesT,
ttg::Runtime Runtime = ttg::ttg_runtime>
inline send_coro_state
broadcastk_coro(RangesT &&keylists) {
RangesT kl = std::forward<RangesT>(keylists); // capture the keylist(s)
if constexpr (ttg::meta::is_tuple_v<RangesT>) {
// treat as tuple
static_assert(sizeof...(Is)+1 == std::tuple_size_v<RangesT>,
"Size of keylist tuple must match the number of output terminals");
co_await ttg::Void{}; // we'll come back once the task is done
ttg::device::detail::broadcastk<0, I, Is...>(kl);
} else if constexpr (!ttg::meta::is_tuple_v<RangesT>) {
// create a tie to the captured keylist
co_await ttg::Void{}; // we'll come back once the task is done
ttg::device::detail::broadcastk<0, I, Is...>(std::tie(kl));
}
}

} // namespace detail

/* overload with explicit terminals and keylist passed by const reference */
Expand All @@ -385,15 +453,36 @@ namespace ttg::device {
ttg::Runtime Runtime = ttg::ttg_runtime>
inline detail::send_t broadcast(rangeT &&keylist, valueT &&value) {
ttg::detail::value_copy_handler<Runtime> copy_handler;
return detail::send_t{broadcast_coro<i>(std::tie(keylist), copy_handler(std::forward<valueT>(value)),
std::move(copy_handler))};
return detail::send_t{
detail::broadcast_coro<i>(std::tie(keylist), copy_handler(std::forward<valueT>(value)),
std::move(copy_handler))};
}


/* overload with explicit terminals and keylist passed by const reference */
template <size_t I, size_t... Is, typename rangeT, typename... out_keysT, typename... out_valuesT,
ttg::Runtime Runtime = ttg::ttg_runtime>
[[nodiscard]]
inline detail::send_t broadcastk(rangeT &&keylist,
std::tuple<ttg::Out<out_keysT, out_valuesT>...> &t) {
ttg::detail::value_copy_handler<Runtime> copy_handler;
return detail::send_t{
detail::broadcastk_coro<I, Is...>(std::forward<rangeT>(keylist), t)};
}

/* overload with implicit terminals and keylist passed by const reference */
template <size_t i, typename rangeT,
ttg::Runtime Runtime = ttg::ttg_runtime>
inline detail::send_t broadcastk(rangeT &&keylist) {
return detail::send_t{detail::broadcastk_coro<i>(std::tie(keylist))};
}


template<typename... Args, ttg::Runtime Runtime = ttg::ttg_runtime>
[[nodiscard]]
std::vector<device::detail::send_t> forward(Args&&... args) {
// TODO: check the cost of this!
return std::vector{std::forward<Args>(args)...};
return std::vector<device::detail::send_t>{std::forward<Args>(args)...};
}

/*******************************************
Expand Down

0 comments on commit 9af4ae8

Please sign in to comment.