From ed03eefc41ce1d19d03d4f671ab0d3ee01cba4e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ars=C3=A8ne=20P=C3=A9rard-Gayot?= Date: Sat, 17 Feb 2024 19:59:32 +0100 Subject: [PATCH] Add generic traversal function --- src/bvh/v2/bvh.h | 65 ++++++++++++++++++++++++++++++------------------ 1 file changed, 41 insertions(+), 24 deletions(-) diff --git a/src/bvh/v2/bvh.h b/src/bvh/v2/bvh.h index 90fb3f7a..ad939413 100644 --- a/src/bvh/v2/bvh.h +++ b/src/bvh/v2/bvh.h @@ -15,6 +15,7 @@ template struct Bvh { using Index = typename Node::Index; using Scalar = typename Node::Scalar; + using Ray = bvh::v2::Ray; std::vector nodes; std::vector prim_ids; @@ -33,13 +34,21 @@ struct Bvh { /// Extracts the BVH rooted at the given node index. inline Bvh extract_bvh(size_t root_id) const; + /// Traverses the BVH from the given index in `start` using the provided stack. Every leaf + /// encountered on the way is processed using the given `LeafFn` function, and every pair of + /// nodes is processed with the function in `HitFn`, which returns a triplet of booleans + /// indicating whether the first child should be processed, whether the second child should be + /// processed, and whether to traverse the second child first instead of the other way around. + template + inline void traverse(Index start, Stack&, LeafFn&&, InnerFn&&) const; + /// Intersects the BVH with a single ray, using the given function to intersect the contents - /// of a leaf. The algorithm starts at the node index `top` and uses the given stack object. + /// of a leaf. The algorithm starts at the node index `start` and uses the given stack object. /// When `IsAnyHit` is true, the function stops at the first intersection (useful for shadow /// rays), otherwise it finds the closest intersection. When `IsRobust` is true, a slower but /// numerically robust ray-box test is used, otherwise a fast, but less precise test is used. template - inline void intersect(Ray& ray, Index top, Stack&, LeafFn&&, InnerFn&& = {}) const; + inline void intersect(const Ray& ray, Index start, Stack&, LeafFn&&, InnerFn&& = {}) const; inline void serialize(OutputStream&) const; static inline Bvh deserialize(InputStream&); @@ -79,19 +88,9 @@ auto Bvh::extract_bvh(size_t root_id) const -> Bvh { } template -template -void Bvh::intersect(Ray& ray, Index start, Stack& stack, LeafFn&& leaf_fn, InnerFn&& inner_fn) const { - auto inv_dir = ray.template get_inv_dir(); - auto inv_org = -inv_dir * ray.org; - auto inv_dir_pad = Ray::pad_inv_dir(inv_dir); - auto octant = ray.get_octant(); - - auto intersect_node = [&] (const Node& node) { - return IsRobust - ? node.intersect_robust(ray, inv_dir, inv_dir_pad, octant) - : node.intersect_fast(ray, inv_dir, inv_org, octant); - }; - +template +void Bvh::traverse(Index start, Stack& stack, LeafFn&& leaf_fn, InnerFn&& inner_fn) const +{ stack.push(start); restart: while (!stack.is_empty()) { @@ -99,20 +98,13 @@ void Bvh::intersect(Ray& ray, Index start, Stack& while (top.prim_count == 0) { auto& left = nodes[top.first_id]; auto& right = nodes[top.first_id + 1]; - - inner_fn(left, right); - - auto intr_left = intersect_node(left); - auto intr_right = intersect_node(right); - - bool hit_left = intr_left.first <= intr_left.second; - bool hit_right = intr_right.first <= intr_right.second; + auto [hit_left, hit_right, should_swap] = inner_fn(left, right); if (hit_left) { auto near_index = left.index; if (hit_right) { auto far_index = right.index; - if (!IsAnyHit && intr_left.first > intr_right.first) + if (should_swap) std::swap(near_index, far_index); stack.push(far_index); } @@ -130,6 +122,31 @@ void Bvh::intersect(Ray& ray, Index start, Stack& } } +template +template +void Bvh::intersect(const Ray& ray, Index start, Stack& stack, LeafFn&& leaf_fn, InnerFn&& inner_fn) const { + auto inv_dir = ray.template get_inv_dir(); + auto inv_org = -inv_dir * ray.org; + auto inv_dir_pad = ray.pad_inv_dir(inv_dir); + auto octant = ray.get_octant(); + + traverse(start, stack, leaf_fn, [&] (const Node& left, const Node& right) { + inner_fn(left, right); + std::pair intr_left, intr_right; + if constexpr (IsRobust) { + intr_left = left.intersect_robust(ray, inv_dir, inv_dir_pad, octant); + intr_right = right.intersect_robust(ray, inv_dir, inv_org, octant); + } else { + intr_left = left.intersect_fast(ray, inv_dir, inv_org, octant); + intr_right = right.intersect_fast(ray, inv_dir, inv_org, octant); + } + return std::make_tuple( + intr_left.first <= intr_left.second, + intr_right.first <= intr_right.second, + !IsAnyHit && intr_left.first > intr_right.first); + }); +} + template void Bvh::serialize(OutputStream& stream) const { stream.write(nodes.size());