From 8e4ba72202691f639514d405624c419eed68d62b Mon Sep 17 00:00:00 2001 From: Joseph Schuchart Date: Thu, 5 Sep 2024 11:42:22 -0400 Subject: [PATCH] MRA: pass K to make_project Signed-off-by: Joseph Schuchart --- examples/madness/mra-device/mrattg-device.cc | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/examples/madness/mra-device/mrattg-device.cc b/examples/madness/mra-device/mrattg-device.cc index 91be4c2fe..b0b069d4b 100644 --- a/examples/madness/mra-device/mrattg-device.cc +++ b/examples/madness/mra-device/mrattg-device.cc @@ -31,18 +31,19 @@ template auto make_project( mra::Domain& domain, ttg::Buffer& f, + std::size_t K, const mra::FunctionData& functiondata, const T thresh, /// should be scalar value not complex ttg::Edge, void> control, ttg::Edge, mra::FunctionReconstructedNode> result) { /* create a non-owning buffer for domain and capture it */ - auto fn = [&, db = ttg::Buffer>(&domain), gl = mra::GLbuffer()] + auto fn = [&, K, db = ttg::Buffer>(&domain), gl = mra::GLbuffer()] (const mra::Key& key) -> TASKTYPE { using tensor_type = typename mra::Tensor; using key_type = typename mra::Key; using node_type = typename mra::FunctionReconstructedNode; - node_type result; + auto result = node_type(key, K); tensor_type& coeffs = result.coeffs; auto outputs = ttg::device::forward(); @@ -73,8 +74,6 @@ auto make_project( const auto& phibar = functiondata.get_phibar(); const auto& hgT = functiondata.get_hgT(); - const std::size_t K = coeffs.dim(0); - /* temporaries */ bool is_leaf; auto is_leaf_scratch = ttg::make_scratch(&is_leaf, ttg::scope::Allocate); @@ -354,7 +353,7 @@ void test(std::size_t K) { // put it into a buffer auto gauss_buffer = ttg::Buffer>(&gaussian); auto start = make_start(project_control); - auto project = make_project(D, gauss_buffer, functiondata, T(1e-6), project_control, project_result); + auto project = make_project(D, gauss_buffer, K, functiondata, T(1e-6), project_control, project_result); auto compress = make_compress(functiondata, project_result, compress_result); auto reconstruct = make_reconstruct(K, functiondata, compress_result, reconstruct_result); auto printer = make_printer(project_result, "projected ", false);