From 9af4ae8c2dbfcbb231d3bda5154b5131dab717fc Mon Sep 17 00:00:00 2001 From: Joseph Schuchart Date: Thu, 5 Sep 2024 11:26:34 -0400 Subject: [PATCH] MRA: make sure everything is sent through ttg::device::broadcast Also adds ttg::device::broadcastk Signed-off-by: Joseph Schuchart --- examples/madness/mra-device/mrattg-device.cc | 11 ++- ttg/ttg/device/task.h | 97 +++++++++++++++++++- 2 files changed, 100 insertions(+), 8 deletions(-) diff --git a/examples/madness/mra-device/mrattg-device.cc b/examples/madness/mra-device/mrattg-device.cc index 29d6d3691..91be4c2fe 100644 --- a/examples/madness/mra-device/mrattg-device.cc +++ b/examples/madness/mra-device/mrattg-device.cc @@ -44,13 +44,14 @@ auto make_project( using node_type = typename mra::FunctionReconstructedNode; node_type result; tensor_type& coeffs = result.coeffs; + auto outputs = ttg::device::forward(); if (key.level() < initial_level(f)) { std::vector> bcast_keys; /* TODO: children() returns an iteratable object but broadcast() expects a contiguous memory range. We need to fix broadcast to support any ranges */ for (auto child : children(key)) bcast_keys.push_back(child); - ttg::broadcastk<0>(bcast_keys); + outputs.push_back(ttg::device::broadcastk<0>(std::move(bcast_keys))); coeffs.current_view() = T(1e7); // set to obviously bad value to detect incorrect use result.is_leaf = false; } @@ -110,13 +111,15 @@ auto make_project( if (!result.is_leaf) { std::vector> bcast_keys; for (auto child : children(key)) bcast_keys.push_back(child); - ttg::broadcastk<0>(bcast_keys); + outputs.push_back(ttg::device::broadcastk<0>(std::move(bcast_keys))); } } - ttg::send<1>(key, std::move(result)); // always produce a result + outputs.push_back(ttg::device::send<1>(key, std::move(result))); // always produce a result + co_await std::move(outputs); }; - return ttg::make_tt(std::move(fn), ttg::edges(control), ttg::edges(result), "project"); + ttg::Edge, void> refine("refine"); + return ttg::make_tt(std::move(fn), ttg::edges(fuse(control, refine)), ttg::edges(refine,result), "project"); } template diff --git a/ttg/ttg/device/task.h b/ttg/ttg/device/task.h index 8e2d14cfc..76e5b8e8f 100644 --- a/ttg/ttg/device/task.h +++ b/ttg/ttg/device/task.h @@ -270,7 +270,7 @@ namespace ttg::device { /* overload for iterable types that extracts the type of the first element */ template struct broadcast_keylist_trait>> { - using key_type = decltype(*std::begin(std::get<0>(std::declval()))); + using key_type = decltype(*std::begin(std::declval())); }; template (std::tie(kl), std::forward>(value)); } } + + + + /** + * broadcastk + */ + + template + inline void broadcastk(const std::tuple &keylists, + std::tuple...> &t) { + std::get(t).broadcast(std::get(keylists)); + if constexpr (sizeof...(Is) > 0) { + detail::broadcastk(keylists, t); + } + } + + template + inline void broadcastk(const std::tuple &keylists) { + using key_t = typename broadcast_keylist_trait< + std::tuple_element_t...>> + >::key_type; + auto *terminal_ptr = ttg::detail::get_out_terminal(I, "ttg::device::broadcastk(keylists)"); + terminal_ptr->broadcast(std::get(keylists)); + if constexpr (sizeof...(Is) > 0) { + ttg::device::detail::broadcastk(keylists); + } + } + + /* overload with explicit terminals */ + template + inline send_coro_state + broadcastk_coro(RangesT &&keylists, + std::tuple...> &t) { + RangesT kl = std::forward(keylists); // capture the keylist(s) + if constexpr (ttg::meta::is_tuple_v) { + // treat as tuple + co_await ttg::Void{}; // we'll come back once the task is done + ttg::device::detail::broadcastk<0, I, Is...>(kl, t); + } else if constexpr (!ttg::meta::is_tuple_v) { + // create a tie to the captured keylist + co_await ttg::Void{}; // we'll come back once the task is done + ttg::device::detail::broadcastk<0, I, Is...>(std::tie(kl), t); + } + } + + /* overload with implicit terminals */ + template + inline send_coro_state + broadcastk_coro(RangesT &&keylists) { + RangesT kl = std::forward(keylists); // capture the keylist(s) + if constexpr (ttg::meta::is_tuple_v) { + // treat as tuple + static_assert(sizeof...(Is)+1 == std::tuple_size_v, + "Size of keylist tuple must match the number of output terminals"); + co_await ttg::Void{}; // we'll come back once the task is done + ttg::device::detail::broadcastk<0, I, Is...>(kl); + } else if constexpr (!ttg::meta::is_tuple_v) { + // create a tie to the captured keylist + co_await ttg::Void{}; // we'll come back once the task is done + ttg::device::detail::broadcastk<0, I, Is...>(std::tie(kl)); + } + } + } // namespace detail /* overload with explicit terminals and keylist passed by const reference */ @@ -385,15 +453,36 @@ namespace ttg::device { ttg::Runtime Runtime = ttg::ttg_runtime> inline detail::send_t broadcast(rangeT &&keylist, valueT &&value) { ttg::detail::value_copy_handler copy_handler; - return detail::send_t{broadcast_coro(std::tie(keylist), copy_handler(std::forward(value)), - std::move(copy_handler))}; + return detail::send_t{ + detail::broadcast_coro(std::tie(keylist), copy_handler(std::forward(value)), + std::move(copy_handler))}; } + + /* overload with explicit terminals and keylist passed by const reference */ + template + [[nodiscard]] + inline detail::send_t broadcastk(rangeT &&keylist, + std::tuple...> &t) { + ttg::detail::value_copy_handler copy_handler; + return detail::send_t{ + detail::broadcastk_coro(std::forward(keylist), t)}; + } + + /* overload with implicit terminals and keylist passed by const reference */ + template + inline detail::send_t broadcastk(rangeT &&keylist) { + return detail::send_t{detail::broadcastk_coro(std::tie(keylist))}; + } + + template [[nodiscard]] std::vector forward(Args&&... args) { // TODO: check the cost of this! - return std::vector{std::forward(args)...}; + return std::vector{std::forward(args)...}; } /*******************************************