Skip to content

Commit

Permalink
Enable backward propagation
Browse files Browse the repository at this point in the history
  • Loading branch information
beomki-yeo committed Dec 2, 2024
1 parent 85171a6 commit a29b4b8
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 18 deletions.
39 changes: 29 additions & 10 deletions core/include/detray/navigation/navigator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,13 @@ class navigator {
/// Scalar representation of the navigation state,
/// @returns distance to next
DETRAY_HOST_DEVICE
scalar_type operator()() const { return target().path; }
scalar_type operator()() const {
if (direction() == navigation::direction::e_forward) {
return target().path;
} else {
return -1.f * target().path;
}
}

/// @returns current volume (index) - const
DETRAY_HOST_DEVICE
Expand Down Expand Up @@ -611,9 +617,14 @@ class navigator {

const auto sf = tracking_surface{det, sf_descr};

auto track_cpy = track;
if (nav_state.direction() == navigation::direction::e_backward) {
track_cpy.set_dir(-1.f * track_cpy.dir());
}

sf.template visit_mask<intersection_initialize<ray_intersector>>(
nav_state, detail::ray(track), sf_descr, det.transform_store(),
ctx,
nav_state, detail::ray(track_cpy), sf_descr,
det.transform_store(), ctx,
sf.is_portal() ? std::array<scalar_type, 2>{0.f, 0.f}
: mask_tol,
mask_tol_scalor, overstep_tol);
Expand Down Expand Up @@ -775,7 +786,8 @@ class navigator {
// - do this only when the navigation state is still coherent
if (navigation.trust_level() == navigation::trust_level::e_high) {
// Update next candidate: If not reachable, 'high trust' is broken
if (!update_candidate(navigation.target(), track, det, cfg, ctx)) {
if (!update_candidate(navigation.direction(), navigation.target(),
track, det, cfg, ctx)) {
navigation.m_status = navigation::status::e_unknown;
navigation.set_fair_trust();
} else {
Expand All @@ -797,7 +809,8 @@ class navigator {

// Else: Track is on module.
// Ready the next candidate after the current module
if (update_candidate(navigation.target(), track, det, cfg,
if (update_candidate(navigation.direction(),
navigation.target(), track, det, cfg,
ctx)) {
return false;
}
Expand All @@ -815,7 +828,8 @@ class navigator {

for (auto &candidate : navigation) {
// Disregard this candidate if it is not reachable
if (!update_candidate(candidate, track, det, cfg, ctx)) {
if (!update_candidate(navigation.direction(), candidate, track,
det, cfg, ctx)) {
// Forcefully set dist to numeric max for sorting
candidate.path = std::numeric_limits<scalar_type>::max();
}
Expand Down Expand Up @@ -897,19 +911,24 @@ class navigator {
/// @returns whether the track can reach this candidate.
template <typename track_t>
DETRAY_HOST_DEVICE inline bool update_candidate(
intersection_type &candidate, const track_t &track,
const detector_type &det, const navigation::config &cfg,
const context_type &ctx) const {
const navigation::direction &nav_dir, intersection_type &candidate,
const track_t &track, const detector_type &det,
const navigation::config &cfg, const context_type &ctx) const {

if (candidate.sf_desc.barcode().is_invalid()) {
return false;
}

const auto sf = tracking_surface{det, candidate.sf_desc};

auto track_cpy = track;
if (nav_dir == navigation::direction::e_backward) {
track_cpy.set_dir(-1.f * track_cpy.dir());
}

// Check whether this candidate is reachable by the track
return sf.template visit_mask<intersection_update<ray_intersector>>(
detail::ray(track), candidate, det.transform_store(), ctx,
detail::ray(track_cpy), candidate, det.transform_store(), ctx,
sf.is_portal() ? std::array<scalar_type, 2>{0.f, 0.f}
: std::array<scalar_type, 2>{cfg.min_mask_tolerance,
cfg.max_mask_tolerance},
Expand Down
1 change: 1 addition & 0 deletions tests/integration_tests/cpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ macro(detray_add_cpu_test algebra)
"builders/material_map_builder.cpp"
"builders/volume_builder.cpp"
"material/material_interaction.cpp"
"propagator/backward_propagation.cpp"
"propagator/covariance_transport.cpp"
"propagator/guided_navigator.cpp"
"propagator/propagator.cpp"
Expand Down
148 changes: 148 additions & 0 deletions tests/integration_tests/cpu/propagator/backward_propagation.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
/** Detray library, part of the ACTS project (R&D line)
*
* (c) 2022-2024 CERN for the benefit of the ACTS project
*
* Mozilla Public License Version 2.0
*/

// Project include(s).
#include "detray/definitions/units.hpp"
#include "detray/detectors/bfield.hpp"
#include "detray/geometry/detail/surface_descriptor.hpp"
#include "detray/geometry/mask.hpp"
#include "detray/geometry/shapes.hpp"
#include "detray/geometry/shapes/unbounded.hpp"
#include "detray/navigation/detail/ray.hpp"
#include "detray/navigation/navigator.hpp"
#include "detray/propagator/actor_chain.hpp"
#include "detray/propagator/actors/parameter_resetter.hpp"
#include "detray/propagator/actors/parameter_transporter.hpp"
#include "detray/propagator/propagator.hpp"
#include "detray/propagator/rk_stepper.hpp"
#include "detray/tracks/tracks.hpp"
#include "detray/utils/axis_rotation.hpp"

// Detray test include(s)
#include "detray/test/utils/detectors/build_telescope_detector.hpp"
#include "detray/test/utils/types.hpp"

// Vecmem include(s)
#include <vecmem/memory/host_memory_resource.hpp>

// google-test include(s).
#include <gtest/gtest.h>

using namespace detray;

// Algebra types
using algebra_t = test::algebra;
using point2 = test::point2;
using vector3 = test::vector3;
using matrix_operator = test::matrix_operator;

constexpr scalar tol{1e-3f};

GTEST_TEST(detray_propagator, backward_propagation) {

vecmem::host_memory_resource host_mr;

// Build in x-direction from given module positions
detail::ray<algebra_t> traj{{0.f, 0.f, 0.f}, 0.f, {1.f, 0.f, 0.f}, -1.f};
std::vector<scalar> positions = {0.f, 10.f, 20.f, 30.f, 40.f, 50.f, 60.f};
/*
std::vector<scalar> positions = {0.f, 100.f, 200.f, 300.f,
400.f, 500.f, 600.f};
*/

tel_det_config<rectangle2D> tel_cfg{200.f * unit<scalar>::mm,
200.f * unit<scalar>::mm};
tel_cfg.positions(positions).pilot_track(traj);

// Build telescope detector with unbounded planes
const auto [det, names] = build_telescope_detector(host_mr, tel_cfg);

// Create b field
using bfield_t = bfield::const_field_t;
vector3 B{1.f * unit<scalar>::T, 1.f * unit<scalar>::T,
1.f * unit<scalar>::T};
const bfield_t hom_bfield = bfield::create_const_field(B);

using navigator_t = navigator<decltype(det)>;
using rk_stepper_t = rk_stepper<bfield_t::view_t, algebra_t>;
using actor_chain_t = actor_chain<dtuple, parameter_transporter<algebra_t>,
parameter_resetter<algebra_t>>;
using propagator_t = propagator<rk_stepper_t, navigator_t, actor_chain_t>;

// Bound vector
bound_parameters_vector<algebra_t> bound_vector{};
bound_vector.set_theta(constant<scalar>::pi_2);
bound_vector.set_qop(-1.f);

// Bound covariance
typename bound_track_parameters<algebra_t>::covariance_type bound_cov =
matrix_operator().template identity<e_bound_size, e_bound_size>();

// Bound track parameter
const bound_track_parameters<algebra_t> bound_param0(
geometry::barcode{}.set_index(0u), bound_vector, bound_cov);

// Actors
parameter_transporter<algebra_t>::state bound_updater{};
parameter_resetter<algebra_t>::state rst{};

propagation::config prop_cfg{};
prop_cfg.stepping.rk_error_tol = 1e-12f * unit<float>::mm;
prop_cfg.navigation.overstep_tolerance = -100.f * unit<float>::um;
propagator_t p{prop_cfg};

// Forward state
propagator_t::state fw_state(bound_param0, hom_bfield, det,
prop_cfg.context);
fw_state.do_debug = true;

// Run propagator
p.propagate(fw_state, detray::tie(bound_updater, rst));

// Print the debug stream
//std::cout << fw_state.debug_stream.str() << std::endl;

// Bound state after propagation
const auto& bound_param1 = fw_state._stepping.bound_params();

// Check if the track reaches the final surface
EXPECT_EQ(bound_param0.surface_link().volume(), 4095u);
EXPECT_EQ(bound_param0.surface_link().index(), 0u);
EXPECT_EQ(bound_param1.surface_link().volume(), 0u);
EXPECT_EQ(bound_param1.surface_link().id(), surface_id::e_sensitive);
EXPECT_EQ(bound_param1.surface_link().index(), 6u);

// Backward state
propagator_t::state bw_state(bound_param1, hom_bfield, det,
prop_cfg.context);
bw_state.do_debug = true;
bw_state._navigation.set_direction(navigation::direction::e_backward);

// Run propagator
p.propagate(bw_state, detray::tie(bound_updater, rst));

// Print the debug stream
//std::cout << bw_state.debug_stream.str() << std::endl;

// Bound state after propagation
const auto& bound_param2 = bw_state._stepping.bound_params();

// Check if the track reaches the initial surface
EXPECT_EQ(bound_param2.surface_link().volume(), 0u);
EXPECT_EQ(bound_param2.surface_link().id(), surface_id::e_sensitive);
EXPECT_EQ(bound_param2.surface_link().index(), 0u);

const auto bound_cov0 = bound_param0.covariance();
const auto bound_cov2 = bound_param2.covariance();
// Check covaraince
for (unsigned int i = 0u; i < e_bound_size; i++) {
for (unsigned int j = 0u; j < e_bound_size; j++) {
EXPECT_NEAR(matrix_operator().element(bound_cov0, i, j),
matrix_operator().element(bound_cov2, i, j), tol);
}
}
}
8 changes: 0 additions & 8 deletions tests/unit_tests/cpu/propagator/covariance_transport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,4 @@ GTEST_TEST(detray_propagator, covariance_transport) {
matrix_operator().element(bound_cov1, i, j), tol);
}
}

// Check covaraince
for (unsigned int i = 0u; i < e_bound_size; i++) {
for (unsigned int j = 0u; j < e_bound_size; j++) {
EXPECT_NEAR(matrix_operator().element(bound_cov0, i, j),
matrix_operator().element(bound_cov1, i, j), tol);
}
}
}

0 comments on commit a29b4b8

Please sign in to comment.