Skip to content

Commit

Permalink
MRA: Explicitly pass K to kernels submit for reconstruct
Browse files Browse the repository at this point in the history
Signed-off-by: Joseph Schuchart <[email protected]>
  • Loading branch information
devreal committed Sep 24, 2024
1 parent 5767733 commit ee7e4a5
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 6 deletions.
3 changes: 2 additions & 1 deletion examples/madness/mra-device/kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -542,9 +542,9 @@ void submit_reconstruct_kernel(
const TensorView<T, NDIM>& from_parent,
const std::array<T*, mra::Key<NDIM>::num_children()>& r_arr,
T* tmp,
std::size_t K,
cudaStream_t stream)
{
const std::size_t K = node.dim(0);
/* runs on a single block */
Dim3 thread_dims = Dim3(K, 1, 1); // figure out how to consider register usage
CALL_KERNEL(reconstruct_kernel, 1, thread_dims, 0, stream)(
Expand All @@ -562,4 +562,5 @@ void submit_reconstruct_kernel<double, 3>(
const TensorView<double, 3>& from_parent,
const std::array<double*, Key<3>::num_children()>& r_arr,
double* tmp,
std::size_t K,
cudaStream_t stream);
1 change: 1 addition & 0 deletions examples/madness/mra-device/kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ void submit_reconstruct_kernel(
const mra::TensorView<T, NDIM>& from_parent,
const std::array<T*, mra::Key<NDIM>::num_children()>& r_arr,
T* tmp,
std::size_t K,
ttg::device::Stream stream);

#endif // HAVE_KERNELS_H
10 changes: 5 additions & 5 deletions examples/madness/mra-device/mrattg-device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ static TASKTYPE do_send_leafs_up(const mra::Key<NDIM>& key, const mra::FunctionR
/// Make a composite operator that implements compression for a single function
template <typename T, mra::Dimension NDIM>
static auto make_compress(
const std::size_t K,
const mra::FunctionData<T, NDIM>& functiondata,
ttg::Edge<mra::Key<NDIM>, mra::FunctionReconstructedNode<T, NDIM>>& in,
ttg::Edge<mra::Key<NDIM>, mra::FunctionCompressedNode<T, NDIM>>& out)
Expand All @@ -190,7 +191,7 @@ static auto make_compress(
/* append out edge to set of edges */
auto compress_out_edges = std::tuple_cat(send_to_compress_edges, std::make_tuple(out));
/* use the tuple variant to handle variable number of inputs while suppressing the output tuple */
auto do_compress = [&](const mra::Key<NDIM>& key,
auto do_compress = [&, K](const mra::Key<NDIM>& key,
//const std::tuple<const FunctionReconstructedNodeTypes&...>& input_frns
const mra::FunctionReconstructedNode<T,NDIM> &in0,
const mra::FunctionReconstructedNode<T,NDIM> &in1,
Expand All @@ -204,7 +205,6 @@ static auto make_compress(
//typename ::detail::tree_types<T,K,NDIM>::compress_out_type& out) {
constexpr const auto num_children = mra::Key<NDIM>::num_children();
constexpr const auto out_terminal_id = num_children;
auto K = in0.coeffs.dim(0);
mra::FunctionCompressedNode<T,NDIM> result(key, K); // The eventual result
auto& d = result.coeffs;
// allocate even though we might not need it
Expand Down Expand Up @@ -300,7 +300,7 @@ auto make_reconstruct(
{
ttg::Edge<mra::Key<NDIM>, mra::Tensor<T,NDIM>> S("S"); // passes scaling functions down

auto do_reconstruct = [&](const mra::Key<NDIM>& key,
auto do_reconstruct = [&, K](const mra::Key<NDIM>& key,
mra::FunctionCompressedNode<T, NDIM>&& node,
const mra::Tensor<T, NDIM>& from_parent) -> TASKTYPE {
const std::size_t K = from_parent.dim(0);
Expand Down Expand Up @@ -340,7 +340,7 @@ auto make_reconstruct(
auto hg_view = hg.current_view();
auto from_parent_view = from_parent.current_view();
submit_reconstruct_kernel(key, node_view, hg_view, from_parent_view,
r_ptrs, tmp_scratch.device_ptr(), ttg::device::current_stream());
r_ptrs, tmp_scratch.device_ptr(), K, ttg::device::current_stream());

// forward() returns a vector that we can push into
#ifndef TTG_ENABLE_HOST
Expand Down Expand Up @@ -426,7 +426,7 @@ void test(std::size_t K) {
auto gauss_buffer = ttg::Buffer<mra::Gaussian<T, NDIM>>(&gaussian);
auto start = make_start(project_control);
auto project = make_project(D, gauss_buffer, K, functiondata, T(1e-6), project_control, project_result);
auto compress = make_compress(functiondata, project_result, compress_result);
auto compress = make_compress(K, functiondata, project_result, compress_result);
auto reconstruct = make_reconstruct(K, functiondata, compress_result, reconstruct_result);
auto printer = make_printer(project_result, "projected ", false);
auto printer2 = make_printer(compress_result, "compressed ", false);
Expand Down

0 comments on commit ee7e4a5

Please sign in to comment.