Skip to content

Commit

Permalink
MRA: pass K to make_project
Browse files Browse the repository at this point in the history
Signed-off-by: Joseph Schuchart <[email protected]>
  • Loading branch information
devreal committed Sep 5, 2024
1 parent 9af4ae8 commit 8e4ba72
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions examples/madness/mra-device/mrattg-device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,19 @@ template<typename FnT, typename T, mra::Dimension NDIM>
auto make_project(
mra::Domain<NDIM>& domain,
ttg::Buffer<FnT>& f,
std::size_t K,
const mra::FunctionData<T, NDIM>& functiondata,
const T thresh, /// should be scalar value not complex
ttg::Edge<mra::Key<NDIM>, void> control,
ttg::Edge<mra::Key<NDIM>, mra::FunctionReconstructedNode<T, NDIM>> result)
{
/* create a non-owning buffer for domain and capture it */
auto fn = [&, db = ttg::Buffer<mra::Domain<NDIM>>(&domain), gl = mra::GLbuffer<T>()]
auto fn = [&, K, db = ttg::Buffer<mra::Domain<NDIM>>(&domain), gl = mra::GLbuffer<T>()]
(const mra::Key<NDIM>& key) -> TASKTYPE {
using tensor_type = typename mra::Tensor<T, NDIM>;
using key_type = typename mra::Key<NDIM>;
using node_type = typename mra::FunctionReconstructedNode<T, NDIM>;
node_type result;
auto result = node_type(key, K);
tensor_type& coeffs = result.coeffs;
auto outputs = ttg::device::forward();

Expand Down Expand Up @@ -73,8 +74,6 @@ auto make_project(
const auto& phibar = functiondata.get_phibar();
const auto& hgT = functiondata.get_hgT();

const std::size_t K = coeffs.dim(0);

/* temporaries */
bool is_leaf;
auto is_leaf_scratch = ttg::make_scratch(&is_leaf, ttg::scope::Allocate);
Expand Down Expand Up @@ -354,7 +353,7 @@ void test(std::size_t K) {
// put it into a buffer
auto gauss_buffer = ttg::Buffer<mra::Gaussian<T, NDIM>>(&gaussian);
auto start = make_start(project_control);
auto project = make_project(D, gauss_buffer, functiondata, T(1e-6), project_control, project_result);
auto project = make_project(D, gauss_buffer, K, functiondata, T(1e-6), project_control, project_result);
auto compress = make_compress(functiondata, project_result, compress_result);
auto reconstruct = make_reconstruct(K, functiondata, compress_result, reconstruct_result);
auto printer = make_printer(project_result, "projected ", false);
Expand Down

0 comments on commit 8e4ba72

Please sign in to comment.