Skip to content

Commit

Permalink
MRA: populate reconstruct output after device has been selected
Browse files Browse the repository at this point in the history
Signed-off-by: Joseph Schuchart <[email protected]>
  • Loading branch information
devreal committed Sep 5, 2024
1 parent 33b7157 commit f6c6352
Showing 1 changed file with 16 additions and 18 deletions.
34 changes: 16 additions & 18 deletions examples/madness/mra-device/mrattg-device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -273,31 +273,13 @@ auto make_reconstruct(
auto r_empty = mra::FunctionReconstructedNode<T,NDIM>(key, K);
r_empty.coeffs.current_view() = T(0.0);
r_empty.is_leaf = false;
// forward() returns a vector that we can push in later
auto sends = ttg::device::forward(ttg::device::send<1>(key, std::move(r_empty)));

/* populate the vector of r's */
std::array<mra::FunctionReconstructedNode<T,NDIM>, key.num_children> r_arr;
for (int i = 0; i < key.num_children; ++i) {
r_arr[i] = mra::FunctionReconstructedNode<T,NDIM>(key, K);
}

/* populate the outputs, we send them at the end
* it's safe to populate them before the kernel is submitted
* since the kernel has no output */
mra::KeyChildren<NDIM> children(key);
for (auto it=children.begin(); it!=children.end(); ++it) {
const mra::Key<NDIM> child= *it;
mra::FunctionReconstructedNode<T,NDIM>& r = r_arr[it.index()];
r.key = child;
r.is_leaf = node.is_child_leaf[it.index()];
if (r.is_leaf) {
sends.push_back(ttg::device::send<1>(child, r));
}
else {
sends.push_back(ttg::device::send<0>(child, r.coeffs));
}
}
// helper lambda to pick apart the std::array
auto do_select = [&]<std::size_t... Is>(std::index_sequence<Is...>){
return ttg::device::select(hg.buffer(), node.coeffs.buffer(), tmp_scratch, (r_arr[Is].coeffs.buffer())...);
Expand All @@ -317,6 +299,22 @@ auto make_reconstruct(
auto from_parent_view = from_parent.current_view();
submit_reconstruct_kernel(key, node_view, hg_view, from_parent_view,
r_ptrs, tmp_scratch.device_ptr(), ttg::device::current_stream());

// forward() returns a vector that we can push into
auto sends = ttg::device::forward(ttg::device::send<1>(key, std::move(r_empty)));
mra::KeyChildren<NDIM> children(key);
for (auto it=children.begin(); it!=children.end(); ++it) {
const mra::Key<NDIM> child= *it;
mra::FunctionReconstructedNode<T,NDIM>& r = r_arr[it.index()];
r.key = child;
r.is_leaf = node.is_child_leaf[it.index()];
if (r.is_leaf) {
sends.push_back(ttg::device::send<1>(child, r));
}
else {
sends.push_back(ttg::device::send<0>(child, r.coeffs));
}
}
#ifndef TTG_ENABLE_HOST
co_await std::move(sends);
#else
Expand Down

0 comments on commit f6c6352

Please sign in to comment.