Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Put detector in navigator_state #293

Merged
merged 13 commits into from
Sep 27, 2022
59 changes: 30 additions & 29 deletions core/include/detray/propagator/navigator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,16 +109,26 @@ class navigator {
typename vector_type<intersection_type>::const_iterator;

public:
using detector_type = navigator::detector_type;

/// Default constructor
state() = default;

state(const detector_type &det) : _detector(&det) {}

/// Constructor with memory resource
DETRAY_HOST
state(vecmem::memory_resource &resource) : _candidates(&resource) {}
state(const detector_type &det, vecmem::memory_resource &resource)
: _detector(&det), _candidates(&resource) {}

/// Constructor from candidates vector_view
DETRAY_HOST_DEVICE state(vector_type<intersection_type> candidates)
: _candidates(candidates) {}
DETRAY_HOST_DEVICE state(const detector_type &det,
vector_type<intersection_type> candidates)
: _detector(&det), _candidates(candidates) {}

/// @returns a pointer of detector
DETRAY_HOST_DEVICE
auto detector() const { return _detector; }

/// Scalar representation of the navigation state,
/// @returns distance to next
Expand Down Expand Up @@ -357,6 +367,9 @@ class navigator {
/// Heartbeat of this navigation flow signals navigation is alive
bool _heartbeat = false;

/// Detector pointer
const detector_type *const _detector;

/// Our cache of candidates (intersections with any kind of surface)
vector_type<intersection_type> _candidates = {};

Expand Down Expand Up @@ -387,16 +400,6 @@ class navigator {
dindex _volume_index = 0;
};

/// Constructor from detector object, which is not owned by the navigator
/// and needs to be guaranteed to have a lifetime beyond that of the
/// navigator
DETRAY_HOST_DEVICE
navigator(const detector_t &d) : _detector(&d) {}

/// @returns reference to the detector
DETRAY_HOST_DEVICE
const detector_t &get_detector() const { return *_detector; }

/// Helper method to initialize a volume.
///
/// Calls the volumes accelerator structure for local navigation, then tests
Expand All @@ -410,8 +413,9 @@ class navigator {
DETRAY_HOST_DEVICE inline bool init(propagator_state_t &propagation) const {

state &navigation = propagation._navigation;
const auto det = navigation.detector();
const auto &track = propagation._stepping();
const auto &volume = _detector->volume_by_index(navigation.volume());
const auto &volume = det->volume_by_index(navigation.volume());

// Clean up state
navigation.clear();
Expand All @@ -421,11 +425,10 @@ class navigator {

// Loop over all indexed objects in volume, intersect and fill
// @todo - will come from the local object finder
const auto &tf_store = _detector->transform_store();
const auto &mask_store = _detector->mask_store();
const auto &tf_store = det->transform_store();
const auto &mask_store = det->mask_store();

for (const auto [obj_idx, obj] :
enumerate(_detector->surfaces(), volume)) {
for (const auto [obj_idx, obj] : enumerate(det->surfaces(), volume)) {

std::size_t count =
mask_store.template execute<intersection_initialize>(
Expand Down Expand Up @@ -526,6 +529,7 @@ class navigator {
propagator_state_t &propagation) const {

state &navigation = propagation._navigation;
const auto det = navigation.detector();
const auto &track = propagation._stepping();

// Current candidates are up to date, nothing left to do
Expand All @@ -539,7 +543,7 @@ class navigator {
navigation.n_candidates() == 1) {

// Update next candidate: If not reachable, 'high trust' is broken
if (not update_candidate(*navigation.next(), track)) {
if (not update_candidate(*navigation.next(), track, det)) {
navigation.set_state(navigation::status::e_unknown,
dindex_invalid,
navigation::trust_level::e_no_trust);
Expand All @@ -565,7 +569,7 @@ class navigator {

// Else: Track is on module.
// Ready the next candidate after the current module
if (update_candidate(*navigation.next(), track)) {
if (update_candidate(*navigation.next(), track, det)) {
return;
}

Expand All @@ -581,7 +585,7 @@ class navigator {

for (auto &candidate : navigation.candidates()) {
// Disregard this candidate if it is not reachable
if (not update_candidate(candidate, track)) {
if (not update_candidate(candidate, track, det)) {
// Forcefully set dist to numeric max for sorting
candidate.path = std::numeric_limits<scalar>::max();
}
Expand Down Expand Up @@ -671,15 +675,15 @@ class navigator {
/// @returns whether the track can reach this candidate.
template <typename track_t>
DETRAY_HOST_DEVICE inline bool update_candidate(
intersection_type &candidate, const track_t &track) const {
intersection_type &candidate, const track_t &track,
const detector_type *det) const {
// Remember the surface this candidate belongs to
const dindex obj_idx = candidate.index;

const auto &mask_store = _detector->mask_store();
const auto &sf = _detector->surface_by_index(obj_idx);
const auto &mask_store = det->mask_store();
const auto &sf = det->surface_by_index(obj_idx);
candidate = mask_store.template execute<intersection_update>(
sf.mask_type(), detail::ray(track), sf,
_detector->transform_store());
sf.mask_type(), detail::ray(track), sf, det->transform_store());

candidate.index = obj_idx;
// Check whether this candidate is reachable by the track
Expand All @@ -702,9 +706,6 @@ class navigator {
return detail::find_if(candidates.begin(), candidates.end(),
not_reachable);
}

/// the containers for all data
const detector_t *const _detector;
};

/// @return the vecmem jagged vector buffer for surface candidates
Expand Down
8 changes: 5 additions & 3 deletions core/include/detray/propagator/propagator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,22 @@ struct propagator {
/// @param candidates buffer for intersections in the navigator
DETRAY_HOST_DEVICE state(
const free_track_parameters_type &t_in,
const typename navigator_t::detector_type &det,
typename actor_chain_t::state actor_states = {},
vector_type<line_plane_intersection> &&candidates = {})
: _stepping(t_in),
_navigation(std::move(candidates)),
_navigation(det, std::move(candidates)),
_actor_states(actor_states) {}

/// Construct the propagation state with bound parameter
DETRAY_HOST_DEVICE state(
const bound_track_parameters_type &param,
const transform3_type &trf3,
const typename stepper_t::transform3_type &trf3,
const typename navigator_t::detector_type &det,
typename actor_chain_t::state actor_states = {},
vector_type<line_plane_intersection> &&candidates = {})
: _stepping(param, trf3),
_navigation(std::move(candidates)),
_navigation(det, std::move(candidates)),
_actor_states(actor_states) {}

// Is the propagation still alive?
Expand Down
4 changes: 2 additions & 2 deletions tests/benchmarks/cuda/benchmark_propagator_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ static void BM_PROPAGATOR_CPU(benchmark::State &state) {
rk_stepper_type s(B_field);

// Create navigator
navigator_host_type n(det);
navigator_host_type n;

// Create propagator
propagator_host_type p(std::move(s), std::move(n));
Expand All @@ -75,7 +75,7 @@ static void BM_PROPAGATOR_CPU(benchmark::State &state) {
for (auto &track : tracks) {

// Create the propagator state
propagator_host_type::state p_state(track);
propagator_host_type::state p_state(track, det);

// Run propagation
p.propagate(p_state);
Expand Down
4 changes: 2 additions & 2 deletions tests/benchmarks/cuda/benchmark_propagator_cuda_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@ __global__ void propagator_benchmark_kernel(
rk_stepper_type s(B_field);

// Create navigator
navigator_device_type n(det);
navigator_device_type n;

// Create propagator
propagator_device_type p(std::move(s), std::move(n));

// Create the propagator state
propagator_device_type::state p_state(
tracks.at(gid), actor_chain<>::state{}, candidates.at(gid));
tracks.at(gid), det, actor_chain<>::state{}, candidates.at(gid));

// Run propagation
p.propagate(p_state);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ TEST(ALGEBRA_PLUGIN, straight_line_navigation) {
using propagator_t = propagator<stepper_t, navigator_t, actor_chain<>>;

// Propagator
propagator_t prop(stepper_t{}, navigator_t{det});
propagator_t prop(stepper_t{}, navigator_t{});

constexpr std::size_t theta_steps{50};
constexpr std::size_t phi_steps{50};
Expand All @@ -75,7 +75,7 @@ TEST(ALGEBRA_PLUGIN, straight_line_navigation) {
// Now follow that ray with a track and check, if we find the same
// volumes and distances along the way
free_track_parameters_type track(ray.pos(), 0, ray.dir(), -1);
propagator_t::state propagation(track);
propagator_t::state propagation(track, det);

// Retrieve navigation information
auto &inspector = propagation._navigation.inspector();
Expand Down Expand Up @@ -141,7 +141,7 @@ TEST(ALGEBRA_PLUGIN, helix_navigation) {
const vector3 B{0. * unit_constants::T, 0. * unit_constants::T,
2. * unit_constants::T};
b_field_t b_field(B);
propagator_t prop(stepper_t{b_field}, navigator_t{det});
propagator_t prop(stepper_t{b_field}, navigator_t{});

constexpr std::size_t theta_steps{10};
constexpr std::size_t phi_steps{10};
Expand All @@ -168,7 +168,7 @@ TEST(ALGEBRA_PLUGIN, helix_navigation) {

// Now follow that helix with the same track and check, if we find
// the same volumes and distances along the way
propagator_t::state propagation(track);
propagator_t::state propagation(track, det);

// Retrieve navigation information
auto &inspector = propagation._navigation.inspector();
Expand Down
21 changes: 12 additions & 9 deletions tests/common/include/tests/common/test_telescope_detector.inl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ struct prop_state {
navigation_t _navigation;

template <typename track_t>
prop_state(const track_t &t_in) : _stepping(t_in) {}
prop_state(const track_t &t_in,
const typename navigation_t::detector_type &det)
: _stepping(t_in), _navigation(det) {}
};

} // anonymous namespace
Expand Down Expand Up @@ -119,18 +121,19 @@ TEST(ALGEBRA_PLUGIN, telescope_detector) {
free_track_parameters<transform3> test_track_x(pos, 0, mom, -1);

// navigators
navigator<decltype(z_tel_det1), inspector_t> navigator_z1(z_tel_det1);
navigator<decltype(z_tel_det2), inspector_t> navigator_z2(z_tel_det2);
navigator<decltype(x_tel_det), inspector_t> navigator_x(x_tel_det);
navigator<decltype(z_tel_det1), inspector_t> navigator_z1;
navigator<decltype(z_tel_det2), inspector_t> navigator_z2;
navigator<decltype(x_tel_det), inspector_t> navigator_x;
using navigation_state_t = decltype(navigator_z1)::state;
using stepping_state_t = rk_stepper_t::state;

// propagation states
prop_state<stepping_state_t, navigation_state_t> propgation_z1(
test_track_z1);
test_track_z1, z_tel_det1);
prop_state<stepping_state_t, navigation_state_t> propgation_z2(
test_track_z2);
prop_state<stepping_state_t, navigation_state_t> propgation_x(test_track_x);
test_track_z2, z_tel_det2);
prop_state<stepping_state_t, navigation_state_t> propgation_x(test_track_x,
x_tel_det);

stepping_state_t &stepping_z1 = propgation_z1._stepping;
stepping_state_t &stepping_z2 = propgation_z2._stepping;
Expand Down Expand Up @@ -201,10 +204,10 @@ TEST(ALGEBRA_PLUGIN, telescope_detector) {
host_mr, n_surfaces, tel_length, pilot_track, rk_stepper_z);

// make at least sure it is navigatable
navigator<decltype(tel_detector), inspector_t> tel_navigator(tel_detector);
navigator<decltype(tel_detector), inspector_t> tel_navigator;

prop_state<stepping_state_t, navigation_state_t> tel_propagation(
pilot_track);
pilot_track, tel_detector);
navigation_state_t &tel_navigation = tel_propagation._navigation;

// run propagation
Expand Down
5 changes: 2 additions & 3 deletions tests/common/include/tests/common/tools_guided_navigator.inl
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,8 @@ TEST(ALGEBRA_PLUGIN, guided_navigator) {
pathlimit_aborter::state pathlimit{200. * unit_constants::cm};

// Propagator
propagator_t p(runge_kutta_stepper{b_field},
guided_navigator{telescope_det});
propagator_t::state guided_state(track, std::tie(pathlimit));
propagator_t p(runge_kutta_stepper{b_field}, guided_navigator{});
propagator_t::state guided_state(track, telescope_det, std::tie(pathlimit));

// Propagate
p.propagate(guided_state);
Expand Down
4 changes: 2 additions & 2 deletions tests/common/include/tests/common/tools_navigator.inl
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,10 @@ TEST(ALGEBRA_PLUGIN, navigator) {
free_track_parameters<transform3> traj(pos, 0, mom, -1);

stepper_t stepper;
navigator_t nav(toy_det);
navigator_t nav;

prop_state<stepper_t::state, navigator_t::state> propagation{
stepper_t::state{traj}, navigator_t::state{}};
stepper_t::state{traj}, navigator_t::state(toy_det, host_mr)};
navigator_t::state &navigation = propagation._navigation;
stepper_t::state &stepping = propagation._stepping;

Expand Down
10 changes: 5 additions & 5 deletions tests/common/include/tests/common/tools_propagator.inl
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ TEST(ALGEBRA_PLUGIN, propagator_line_stepper) {
const vector3 mom{1., 1., 0.};
free_track_parameters<transform3> track(pos, 0, mom, -1);

propagator_t p(stepper_t{}, navigator_t{d});
propagator_t p(stepper_t{}, navigator_t{});

propagator_t::state state(track);
propagator_t::state state(track, d);

EXPECT_TRUE(p.propagate(state))
<< state._navigation.inspector().to_string() << std::endl;
Expand Down Expand Up @@ -142,7 +142,7 @@ TEST_P(PropagatorWithRkStepper, propagator_rk_stepper) {
const b_field_t b_field(B);

// Propagator is built from the stepper and navigator
propagator_t p(stepper_t{b_field}, navigator_t{d});
propagator_t p(stepper_t{b_field}, navigator_t{});

// Iterate through uniformly distributed momentum directions
for (auto track :
Expand All @@ -167,8 +167,8 @@ TEST_P(PropagatorWithRkStepper, propagator_rk_stepper) {
helix_insp_state, lim_print_insp_state, pathlimit_aborter_state);

// Init propagator states
propagator_t::state state(track, actor_states);
propagator_t::state lim_state(lim_track, lim_actor_states);
propagator_t::state state(track, d, actor_states);
propagator_t::state lim_state(lim_track, d, lim_actor_states);

// Set step constraints
state._stepping.template set_constraint<step::constraint::e_accuracy>(
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/cuda/navigator_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ TEST(navigator_cuda, navigator) {
n_edc_layers);

// Create navigator
navigator_host_t nav(det);
navigator_host_t nav;

// Create the vector of initial track parameters
vecmem::vector<free_track_parameters<transform3>> tracks_host(&mng_mr);
Expand Down Expand Up @@ -63,7 +63,7 @@ TEST(navigator_cuda, navigator) {
stepper_t stepper;

prop_state<navigator_host_t::state> propagation{
stepper_t::state{track}, navigator_host_t::state{mng_mr}};
stepper_t::state{track}, navigator_host_t::state(det, mng_mr)};

navigator_host_t::state& navigation = propagation._navigation;
stepper_t::state& stepping = propagation._stepping;
Expand Down
5 changes: 3 additions & 2 deletions tests/unit_tests/cuda/navigator_cuda_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,14 @@ __global__ void navigator_test_kernel(
return;
}

navigator_device_t nav(det);
navigator_device_t nav;

auto& traj = tracks.at(gid);
stepper_t stepper;

prop_state<navigator_device_t::state> propagation{
stepper_t::state{traj}, navigator_device_t::state{candidates.at(gid)}};
stepper_t::state{traj},
navigator_device_t::state(det, candidates.at(gid))};

navigator_device_t::state& navigation = propagation._navigation;
stepper_t::state& stepping = propagation._stepping;
Expand Down
Loading