Skip to content

Commit

Permalink
Speed up masked sums using multi-resolution
Browse files Browse the repository at this point in the history
  • Loading branch information
victorreijgwart committed Dec 18, 2024
1 parent 3d81e74 commit c19f5bb
Show file tree
Hide file tree
Showing 4 changed files with 206 additions and 76 deletions.
3 changes: 2 additions & 1 deletion examples/cpp/edit/crop_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ int main(int, char**) {
// Crop the map
const Point3D t_W_center{-2.2, -1.4, 0.0};
const FloatingPoint radius = 3.0;
const Sphere cropping_sphere{t_W_center, radius};
auto thread_pool = std::make_shared<ThreadPool>(); // Optional
edit::crop_to_sphere(*map, t_W_center, radius, 0, thread_pool);
edit::crop(*map, cropping_sphere, 0, thread_pool);

// Save the map
const std::filesystem::path output_map_path =
Expand Down
3 changes: 2 additions & 1 deletion examples/cpp/edit/sum_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ int main(int, char**) {
// Crop the map
const Point3D t_W_center{-2.2, -1.4, 0.0};
const FloatingPoint radius = 3.0;
const Sphere cropping_sphere{t_W_center, radius};
auto thread_pool = std::make_shared<ThreadPool>(); // Optional
edit::crop_to_sphere(*map, t_W_center, radius, 0, thread_pool);
edit::crop(*map, cropping_sphere, 0, thread_pool);

// Create a translated copy
Transformation3D T_AB;
Expand Down
239 changes: 175 additions & 64 deletions library/cpp/include/wavemap/core/utils/edit/impl/sum_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,32 @@
#include "wavemap/core/indexing/index_conversions.h"
#include "wavemap/core/utils/iterate/grid_iterator.h"
#include "wavemap/core/utils/shape/aabb.h"
#include "wavemap/core/utils/shape/intersection_tests.h"

namespace wavemap::edit {
namespace detail {
template <typename MapT>
void sumNodeRecursive(
typename MapT::Block::OctreeType::NodeRefType node_A,
typename MapT::Block::OctreeType::NodeConstRefType node_B) {
using NodeRefType = decltype(node_A);
using NodeConstPtrType = typename MapT::Block::OctreeType::NodeConstPtrType;

// Sum
node_A.data() += node_B.data();

// Recursively handle all child nodes
for (NdtreeIndexRelativeChild child_idx = 0;
child_idx < OctreeIndex::kNumChildren; ++child_idx) {
NodeConstPtrType child_node_B = node_B.getChild(child_idx);
if (!child_node_B) {
continue;
}
NodeRefType child_node_A = node_A.getOrAllocateChild(child_idx);
sumNodeRecursive<MapT>(child_node_A, *child_node_B);
}
}

template <typename MapT, typename SamplingFn>
void sumLeavesBatch(typename MapT::Block::OctreeType::NodeRefType node,
const OctreeIndex& node_index, FloatingPoint& node_value,
Expand Down Expand Up @@ -73,29 +96,125 @@ void sumNodeRecursive(typename MapT::Block::OctreeType::NodeRefType node,
node_value = new_value;
}

template <typename MapT>
void sumNodeRecursive(
typename MapT::Block::OctreeType::NodeRefType node_A,
typename MapT::Block::OctreeType::NodeConstRefType node_B) {
using NodeRefType = decltype(node_A);
using NodeConstPtrType = typename MapT::Block::OctreeType::NodeConstPtrType;
template <typename MapT, typename ShapeT>
void sumLeavesBatch(typename MapT::Block::OctreeType::NodeRefType node,
const OctreeIndex& node_index, FloatingPoint& node_value,
ShapeT&& mask, FloatingPoint summand,
FloatingPoint min_cell_width) {
// Decompress child values
using Transform = typename MapT::Block::Transform;
auto& node_details = node.data();
auto child_values = Transform::backward({node_value, {node_details}});

// Sum
node_A.data() += node_B.data();
// Sum all children
for (NdtreeIndexRelativeChild child_idx = 0;
child_idx < OctreeIndex::kNumChildren; ++child_idx) {
const OctreeIndex child_index = node_index.computeChildIndex(child_idx);
const Point3D t_W_child =
convert::nodeIndexToCenterPoint(child_index, min_cell_width);
if (shape::is_inside(t_W_child, mask)) {
child_values[child_idx] += summand;
}
}

// Recursively handle all child nodes
// Compress
const auto [new_value, new_details] =
MapT::Block::Transform::forward(child_values);
node_details = new_details;
node_value = new_value;
}

// Adds `summand` to all cells below `node` that are covered by `mask`,
// working at the coarsest resolution possible.
//
// For each child of `node`, based on the child's axis-aligned bounding box:
//  - if the box does not overlap the mask, the child is skipped;
//  - if the box is fully inside the mask, `summand` is added to the child's
//    value directly, without descending any further;
//  - otherwise the child node is allocated (if needed) and refined, using
//    sumLeavesBatch() once the child's height is at most
//    `termination_height + 1`, and recursing otherwise.
// The updated child values are then compressed and written back to
// `node.data()` and `node_value`.
//
// @param node                Node whose subtree is updated.
// @param node_index          Index of `node`.
// @param node_value          In/out: value of `node`, before/after the update.
// @param mask                Shape tested with shape::overlaps() and
//                            shape::is_inside().
// @param summand             Value to add inside the mask.
// @param min_cell_width      Passed to convert::nodeIndexToAABB() to turn
//                            child indices into bounding boxes.
// @param termination_height  Height threshold controlling when refinement
//                            switches from recursion to sumLeavesBatch().
template <typename MapT, typename ShapeT>
void sumNodeRecursive(typename MapT::Block::OctreeType::NodeRefType node,
                      const OctreeIndex& node_index, FloatingPoint& node_value,
                      ShapeT&& mask, FloatingPoint summand,
                      FloatingPoint min_cell_width,
                      IndexElement termination_height) {
  using NodeRefType = decltype(node);

  // Decompress child values
  using Transform = typename MapT::Block::Transform;
  auto& node_details = node.data();
  auto child_values = Transform::backward({node_value, {node_details}});

  // Handle each child
  for (NdtreeIndexRelativeChild child_idx = 0;
       child_idx < OctreeIndex::kNumChildren; ++child_idx) {
    // If the node is fully outside the shape, skip it
    const OctreeIndex child_index = node_index.computeChildIndex(child_idx);
    const AABB<Point3D> child_aabb =
        convert::nodeIndexToAABB(child_index, min_cell_width);
    if (!shape::overlaps(child_aabb, mask)) {
      continue;
    }

    // If the node is fully inside the shape, sum at the current resolution
    auto& child_value = child_values[child_idx];
    if (shape::is_inside(child_aabb, mask)) {
      child_value += summand;
      continue;
    }

    // Otherwise, continue at a higher resolution.
    // The mask is passed on as an lvalue on purpose: it is reused for every
    // child of this loop, so it must never be forwarded as an rvalue here.
    NodeRefType child_node = node.getOrAllocateChild(child_idx);
    if (child_index.height <= termination_height + 1) {
      sumLeavesBatch<MapT>(child_node, child_index, child_value, mask, summand,
                           min_cell_width);
    } else {
      sumNodeRecursive<MapT>(child_node, child_index, child_value, mask,
                             summand, min_cell_width, termination_height);
    }
  }

  // Compress
  const auto [new_value, new_details] = Transform::forward(child_values);
  node_details = new_details;
  node_value = new_value;
}
} // namespace detail

// Adds map_B onto map_A, block by block.
//
// Both maps must have the same tree height and minimum cell width (checked).
// For every block of map_B, the block at the same index in map_A is fetched
// or allocated, flagged as updated and in need of pruning, and has B's root
// scale and all of B's node data added to it. Each block is pruned once its
// sum is done. If a thread pool is given, the per-block node summation and
// pruning run as pool tasks, and this function waits for all of them before
// returning; otherwise everything runs inline.
template <typename MapT>
void sum(MapT& map_A, const MapT& map_B,
         const std::shared_ptr<ThreadPool>& thread_pool) {
  CHECK_EQ(map_A.getTreeHeight(), map_B.getTreeHeight());
  CHECK_EQ(map_A.getMinCellWidth(), map_B.getMinCellWidth());
  using NodePtrType = typename MapT::Block::OctreeType::NodePtrType;
  using NodeConstPtrType = typename MapT::Block::OctreeType::NodeConstPtrType;

  // Visit every block of B and merge it into its counterpart in A
  map_B.forEachBlock(
      [&map_A, &thread_pool](const Index3D& block_index, const auto& block_B) {
        auto& block_A = map_A.getOrAllocateBlock(block_index);

        // Flag the destination block as modified
        block_A.setLastUpdatedStamp();
        block_A.setNeedsPruning();

        // Block averages (wavelet scale coefficients) are summed right away
        block_A.getRootScale() += block_B.getRootScale();

        // Node values (wavelet detail coefficients) are summed recursively,
        // followed by a prune of the destination block
        NodePtrType dst_root = &block_A.getRootNode();
        NodeConstPtrType src_root = &block_B.getRootNode();
        auto sum_and_prune = [dst_root, src_root, dst_block = &block_A]() {
          detail::sumNodeRecursive<MapT>(*dst_root, *src_root);
          dst_block->prune();
        };

        // Run the job inline when no thread pool is available
        if (!thread_pool) {
          sum_and_prune();
          return;
        }
        thread_pool->add_task(sum_and_prune);
      });

  // Block until every queued job has completed
  if (thread_pool) {
    thread_pool->wait_all();
  }
}

template <typename MapT, typename SamplingFn>
void sum(MapT& map, SamplingFn sampling_function,
const std::unordered_set<Index3D, IndexHash<3>>& block_indices,
Expand All @@ -122,7 +241,7 @@ void sum(MapT& map, SamplingFn sampling_function,
NodePtrType root_node_ptr = &block.getRootNode();
const OctreeIndex root_node_index{tree_height, block_index};

// Recursively crop all nodes
// Recursively sum all nodes
if (thread_pool) {
thread_pool->add_task([root_node_ptr, root_node_index, root_value_ptr,
block_ptr = &block,
Expand All @@ -147,55 +266,14 @@ void sum(MapT& map, SamplingFn sampling_function,
}
}

// NOTE(review): this definition of sum(map_A, map_B, thread_pool) appears
// twice in this file with identical contents — confirm which copy should be
// kept and drop the other.
//
// Adds map_B onto map_A, block by block.
//
// Both maps must have the same tree height and minimum cell width (checked).
// For every block of map_B, the block at the same index in map_A is fetched
// or allocated, flagged as updated and in need of pruning, and has B's root
// scale and all of B's node data added to it. Each block is pruned once its
// sum is done. If a thread pool is given, the per-block node summation and
// pruning run as pool tasks, and this function waits for all of them before
// returning; otherwise everything runs inline.
template <typename MapT>
void sum(MapT& map_A, const MapT& map_B,
         const std::shared_ptr<ThreadPool>& thread_pool) {
  CHECK_EQ(map_A.getTreeHeight(), map_B.getTreeHeight());
  CHECK_EQ(map_A.getMinCellWidth(), map_B.getMinCellWidth());
  using NodePtrType = typename MapT::Block::OctreeType::NodePtrType;
  using NodeConstPtrType = typename MapT::Block::OctreeType::NodeConstPtrType;

  // Visit every block of B and merge it into its counterpart in A
  map_B.forEachBlock(
      [&map_A, &thread_pool](const Index3D& block_index, const auto& block_B) {
        auto& block_A = map_A.getOrAllocateBlock(block_index);

        // Flag the destination block as modified
        block_A.setLastUpdatedStamp();
        block_A.setNeedsPruning();

        // Block averages (wavelet scale coefficients) are summed right away
        block_A.getRootScale() += block_B.getRootScale();

        // Node values (wavelet detail coefficients) are summed recursively,
        // followed by a prune of the destination block
        NodePtrType dst_root = &block_A.getRootNode();
        NodeConstPtrType src_root = &block_B.getRootNode();
        auto sum_and_prune = [dst_root, src_root, dst_block = &block_A]() {
          detail::sumNodeRecursive<MapT>(*dst_root, *src_root);
          dst_block->prune();
        };

        // Run the job inline when no thread pool is available
        if (!thread_pool) {
          sum_and_prune();
          return;
        }
        thread_pool->add_task(sum_and_prune);
      });

  // Block until every queued job has completed
  if (thread_pool) {
    thread_pool->wait_all();
  }
}

template <typename MapT, typename ShapeT>
void sum(MapT& map, ShapeT shape, FloatingPoint update,
void sum(MapT& map, ShapeT mask, FloatingPoint summand,
const std::shared_ptr<ThreadPool>& thread_pool) {
// Find the blocks that overlap with the shape
const FloatingPoint block_width =
convert::heightToCellWidth(map.getMinCellWidth(), map.getTreeHeight());
const FloatingPoint block_width_inv = 1.f / block_width;
const auto aabb = static_cast<AABB<Point3D>>(shape);
const auto aabb = static_cast<AABB<Point3D>>(mask);
const Index3D block_index_min =
convert::pointToFloorIndex(aabb.min, block_width_inv);
const Index3D block_index_max =
Expand All @@ -205,14 +283,47 @@ void sum(MapT& map, ShapeT shape, FloatingPoint update,
block_indices.emplace(block_index);
}

// Add the update to all cells whose centers lie inside the shape
auto sampling_function = [&shape, update](const Point3D& t_A_p) {
if (shape.contains(t_A_p)) {
return update;
// Make sure all overlapping blocks have been allocated
for (const Index3D& block_index : block_indices) {
map.getOrAllocateBlock(block_index);
}

// Apply the sum to each overlapping block
using NodePtrType = typename MapT::Block::OctreeType::NodePtrType;
const IndexElement tree_height = map.getTreeHeight();
const FloatingPoint min_cell_width = map.getMinCellWidth();
for (const Index3D& block_index : block_indices) {
// Indicate that the block has changed
auto& block = *CHECK_NOTNULL(map.getBlock(block_index));
block.setLastUpdatedStamp();
block.setNeedsPruning();

// Get pointers to the root value and node, which contain the wavelet
// scale and detail coefficients, respectively
FloatingPoint* root_value_ptr = &block.getRootScale();
NodePtrType root_node_ptr = &block.getRootNode();
const OctreeIndex root_node_index{tree_height, block_index};

// Recursively sum all nodes
if (thread_pool) {
thread_pool->add_task([root_node_ptr, root_node_index, root_value_ptr,
block_ptr = &block, &mask, summand,
min_cell_width]() mutable {
detail::sumNodeRecursive<MapT>(*root_node_ptr, root_node_index,
*root_value_ptr, mask, summand,
min_cell_width, 0);
});
} else {
detail::sumNodeRecursive<MapT>(*root_node_ptr, root_node_index,
*root_value_ptr, mask, summand,
min_cell_width, 0);
}
return 0.f;
};
sum(map, sampling_function, block_indices, 0, thread_pool);
}

// Wait for all parallel jobs to finish
if (thread_pool) {
thread_pool->wait_all();
}
}
} // namespace wavemap::edit

Expand Down
37 changes: 27 additions & 10 deletions library/cpp/include/wavemap/core/utils/edit/sum.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,37 +11,54 @@

namespace wavemap::edit {
namespace detail {
// Recursively sum two maps together
// Adds the data of node_B and of all its descendants onto the corresponding
// nodes below node_A, allocating missing children of node_A as needed.
template <typename MapT>
void sumNodeRecursive(
    typename MapT::Block::OctreeType::NodeRefType node_A,
    typename MapT::Block::OctreeType::NodeConstRefType node_B);

// Recursively add a sampled value
// node_value is an in/out parameter holding the value of the given node.
template <typename MapT, typename SamplingFn>
void sumLeavesBatch(typename MapT::Block::OctreeType::NodeRefType node,
                    const OctreeIndex& node_index, FloatingPoint& node_value,
                    SamplingFn&& sampling_function,
                    FloatingPoint min_cell_width);
template <typename MapT, typename SamplingFn>
void sumNodeRecursive(typename MapT::Block::OctreeType::NodeRefType node,
                      const OctreeIndex& node_index, FloatingPoint& node_value,
                      SamplingFn&& sampling_function,
                      FloatingPoint min_cell_width,
                      IndexElement termination_height = 0);

// Recursively add a scalar value to all cells within a given shape
// sumLeavesBatch tests the center point of each child of the node against
// the mask. sumNodeRecursive tests the children's bounding boxes, such that
// children that are fully inside the mask are updated without descending
// further and children that do not overlap it are skipped.
// node_value is an in/out parameter holding the value of the given node.
template <typename MapT, typename ShapeT>
void sumLeavesBatch(typename MapT::Block::OctreeType::NodeRefType node,
                    const OctreeIndex& node_index, FloatingPoint& node_value,
                    ShapeT&& mask, FloatingPoint summand,
                    FloatingPoint min_cell_width);
template <typename MapT, typename ShapeT>
void sumNodeRecursive(typename MapT::Block::OctreeType::NodeRefType node,
                      const OctreeIndex& node_index, FloatingPoint& node_value,
                      ShapeT&& mask, FloatingPoint summand,
                      FloatingPoint min_cell_width,
                      IndexElement termination_height = 0);
}  // namespace detail

// Sum two maps together
// The result is stored in map_A. Both maps must have the same tree height and
// minimum cell width. If a thread pool is provided, the blocks are processed
// in parallel and the function returns once all of them are done.
template <typename MapT>
void sum(MapT& map_A, const MapT& map_B,
         const std::shared_ptr<ThreadPool>& thread_pool = nullptr);

// Add a sampled value to all cells within a given list of blocks
template <typename MapT, typename SamplingFn>
void sum(MapT& map, SamplingFn sampling_function,
         const std::unordered_set<Index3D, IndexHash<3>>& block_indices,
         IndexElement termination_height = 0,
         const std::shared_ptr<ThreadPool>& thread_pool = nullptr);

// Add a scalar value to all cells within a given shape
// The mask must be convertible to an AABB<Point3D>, which is used to find the
// blocks that it overlaps.
template <typename MapT, typename ShapeT>
void sum(MapT& map, ShapeT mask, FloatingPoint summand,
         const std::shared_ptr<ThreadPool>& thread_pool = nullptr);
}  // namespace wavemap::edit

Expand Down

0 comments on commit c19f5bb

Please sign in to comment.