From f6c63523f5fa62a5064ff1fc1801a74a840db6a8 Mon Sep 17 00:00:00 2001 From: Joseph Schuchart Date: Thu, 5 Sep 2024 15:19:20 -0400 Subject: [PATCH] MRA: populate reconstruct output after device has been selected Signed-off-by: Joseph Schuchart --- examples/madness/mra-device/mrattg-device.cc | 34 +++++++++----------- 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/examples/madness/mra-device/mrattg-device.cc b/examples/madness/mra-device/mrattg-device.cc index 1c460a97b..a13212156 100644 --- a/examples/madness/mra-device/mrattg-device.cc +++ b/examples/madness/mra-device/mrattg-device.cc @@ -273,8 +273,6 @@ auto make_reconstruct( auto r_empty = mra::FunctionReconstructedNode(key, K); r_empty.coeffs.current_view() = T(0.0); r_empty.is_leaf = false; - // forward() returns a vector that we can push in later - auto sends = ttg::device::forward(ttg::device::send<1>(key, std::move(r_empty))); /* populate the vector of r's */ std::array, key.num_children> r_arr; @@ -282,22 +280,6 @@ auto make_reconstruct( r_arr[i] = mra::FunctionReconstructedNode(key, K); } - /* populate the outputs, we send them at the end - * it's safe to populate them before the kernel is submitted - * since the kernel has no output */ - mra::KeyChildren children(key); - for (auto it=children.begin(); it!=children.end(); ++it) { - const mra::Key child= *it; - mra::FunctionReconstructedNode& r = r_arr[it.index()]; - r.key = child; - r.is_leaf = node.is_child_leaf[it.index()]; - if (r.is_leaf) { - sends.push_back(ttg::device::send<1>(child, r)); - } - else { - sends.push_back(ttg::device::send<0>(child, r.coeffs)); - } - } // helper lambda to pick apart the std::array auto do_select = [&](std::index_sequence){ return ttg::device::select(hg.buffer(), node.coeffs.buffer(), tmp_scratch, (r_arr[Is].coeffs.buffer())...); @@ -317,6 +299,22 @@ auto make_reconstruct( auto from_parent_view = from_parent.current_view(); submit_reconstruct_kernel(key, node_view, hg_view, from_parent_view, r_ptrs, tmp_scratch.device_ptr(), ttg::device::current_stream()); + + // forward() returns a vector that we can push into + auto sends = ttg::device::forward(ttg::device::send<1>(key, std::move(r_empty))); + mra::KeyChildren children(key); + for (auto it=children.begin(); it!=children.end(); ++it) { + const mra::Key child= *it; + mra::FunctionReconstructedNode& r = r_arr[it.index()]; + r.key = child; + r.is_leaf = node.is_child_leaf[it.index()]; + if (r.is_leaf) { + sends.push_back(ttg::device::send<1>(child, r)); + } + else { + sends.push_back(ttg::device::send<0>(child, r.coeffs)); + } + } #ifndef TTG_ENABLE_HOST co_await std::move(sends); #else