Skip to content

Commit

Permalink
Make sure that all SYCL kernels would have a unique kernel class.
Browse files Browse the repository at this point in the history
Trying to avoid confusion at runtime about which kernel is which.
  • Loading branch information
krasznaa committed Nov 18, 2024
1 parent 268a632 commit ed91591
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,28 @@
#include <detray/propagator/rk_stepper.hpp>

namespace traccc::sycl {
namespace kernels::combinatorial_kalman_filter_constant_field_default_detector {

struct make_barcode_sequence;
struct apply_interaction;
struct find_tracks;
struct fill_sort_keys;
struct propagate_to_next_surface;
struct build_tracks;
struct prune_tracks;

struct kernels {
using make_barcode_sequence_kernel_type = make_barcode_sequence;
using apply_interaction_kernel_type = apply_interaction;
using find_tracks_kernel_type = find_tracks;
using fill_sort_keys_kernel_type = fill_sort_keys;
using propagate_to_next_surface_kernel_type = propagate_to_next_surface;
using build_tracks_kernel_type = build_tracks;
using prune_tracks_kernel_type = prune_tracks;
}; // namespace kernels

} // namespace
// kernels::combinatorial_kalman_filter_constant_field_default_detector

combinatorial_kalman_filter_algorithm::output_type
combinatorial_kalman_filter_algorithm::operator()(
Expand All @@ -30,9 +52,10 @@ combinatorial_kalman_filter_algorithm::operator()(
detray::rk_stepper<detray::bfield::const_field_t::view_t,
default_detector::device::algebra_type,
detray::constrained_step<>>,
detray::navigator<const default_detector::device>>(
det, field, measurements, seeds, m_config, m_mr, m_copy,
details::get_queue(m_queue));
detray::navigator<const default_detector::device>,
kernels::combinatorial_kalman_filter_constant_field_default_detector::
kernels>(det, field, measurements, seeds, m_config, m_mr, m_copy,
details::get_queue(m_queue));
}

} // namespace traccc::sycl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,29 @@
#include <detray/propagator/rk_stepper.hpp>

namespace traccc::sycl {
namespace kernels::
combinatorial_kalman_filter_constant_field_telescope_detector {

struct make_barcode_sequence;
struct apply_interaction;
struct find_tracks;
struct fill_sort_keys;
struct propagate_to_next_surface;
struct build_tracks;
struct prune_tracks;

struct kernels {
using make_barcode_sequence_kernel_type = make_barcode_sequence;
using apply_interaction_kernel_type = apply_interaction;
using find_tracks_kernel_type = find_tracks;
using fill_sort_keys_kernel_type = fill_sort_keys;
using propagate_to_next_surface_kernel_type = propagate_to_next_surface;
using build_tracks_kernel_type = build_tracks;
using prune_tracks_kernel_type = prune_tracks;
}; // namespace kernels

} // namespace
// kernels::combinatorial_kalman_filter_constant_field_telescope_detector

combinatorial_kalman_filter_algorithm::output_type
combinatorial_kalman_filter_algorithm::operator()(
Expand All @@ -30,9 +53,10 @@ combinatorial_kalman_filter_algorithm::operator()(
detray::rk_stepper<detray::bfield::const_field_t::view_t,
telescope_detector::device::algebra_type,
detray::constrained_step<>>,
detray::navigator<const telescope_detector::device>>(
det, field, measurements, seeds, m_config, m_mr, m_copy,
details::get_queue(m_queue));
detray::navigator<const telescope_detector::device>,
kernels::combinatorial_kalman_filter_constant_field_telescope_detector::
kernels>(det, field, measurements, seeds, m_config, m_mr, m_copy,
details::get_queue(m_queue));
}

} // namespace traccc::sycl
21 changes: 13 additions & 8 deletions device/sycl/src/finding/find_tracks.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ namespace traccc::sycl::details {
///
/// @tparam stepper_t The stepper type used for the track propagation
/// @tparam navigator_t The navigator type used for the track navigation
/// @tparam kernels_t Structure with unique "kernel structures"
///
/// @param det A view of the detector object
/// @param field The magnetic field object
Expand All @@ -72,7 +73,7 @@ namespace traccc::sycl::details {
///
/// @return A buffer of the found track candidates
///
template <typename stepper_t, typename navigator_t>
template <typename stepper_t, typename navigator_t, typename kernels_t>
track_candidate_container_types::buffer find_tracks(
const typename navigator_t::detector_type::view_type& det,
const typename stepper_t::magnetic_field_type& field,
Expand Down Expand Up @@ -129,7 +130,8 @@ track_candidate_container_types::buffer find_tracks(

queue
.submit([&](::sycl::handler& h) {
h.parallel_for(
h.parallel_for<
typename kernels_t::make_barcode_sequence_kernel_type>(
calculate1DimNdRange(n_modules, 64),
[uniques_view = vecmem::get_data(uniques_buffer),
barcodes_view = vecmem::get_data(barcodes_buffer)](
Expand Down Expand Up @@ -182,7 +184,8 @@ track_candidate_container_types::buffer find_tracks(

queue
.submit([&](::sycl::handler& h) {
h.parallel_for(
h.parallel_for<
typename kernels_t::apply_interaction_kernel_type>(
calculate1DimNdRange(n_in_params, 64),
[config, det, n_in_params,
in_params = vecmem::get_data(in_params_buffer),
Expand Down Expand Up @@ -244,7 +247,7 @@ track_candidate_container_types::buffer find_tracks(
shared_candidates_size(1, h);

// Launch the kernel.
h.parallel_for(
h.parallel_for<typename kernels_t::find_tracks_kernel_type>(
calculate1DimNdRange(n_in_params, nFindTracksThreads),
[config, det, measurements,
in_params = vecmem::get_data(in_params_buffer),
Expand Down Expand Up @@ -308,7 +311,8 @@ track_candidate_container_types::buffer find_tracks(

queue
.submit([&](::sycl::handler& h) {
h.parallel_for(
h.parallel_for<
typename kernels_t::fill_sort_keys_kernel_type>(
calculate1DimNdRange(n_candidates, 256),
[in_params = vecmem::get_data(in_params_buffer),
keys = vecmem::get_data(keys_buffer),
Expand Down Expand Up @@ -356,7 +360,8 @@ track_candidate_container_types::buffer find_tracks(
// surface.
queue
.submit([&](::sycl::handler& h) {
h.parallel_for(
h.parallel_for<typename kernels_t::
propagate_to_next_surface_kernel_type>(
calculate1DimNdRange(n_candidates, 64),
[config, det, field,
in_params = vecmem::get_data(in_params_buffer),
Expand Down Expand Up @@ -440,7 +445,7 @@ track_candidate_container_types::buffer find_tracks(

queue
.submit([&](::sycl::handler& h) {
h.parallel_for(
h.parallel_for<typename kernels_t::build_tracks_kernel_type>(
calculate1DimNdRange(n_tips_total, 64),
[config, measurements, seeds,
links = vecmem::get_data(links_buffer),
Expand Down Expand Up @@ -478,7 +483,7 @@ track_candidate_container_types::buffer find_tracks(

queue
.submit([&](::sycl::handler& h) {
h.parallel_for(
h.parallel_for<typename kernels_t::prune_tracks_kernel_type>(
calculate1DimNdRange(n_valid_tracks, 64),
[track_candidates,
valid_indices = vecmem::get_data(valid_indices_buffer),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ namespace traccc::sycl::details {
/// functions
///
/// @tparam detector_t The detector type to use
/// @tparam kernel_t The kernel type to use
///
/// @param det_view The view of the detector to use
/// @param measurements_view The view of the measurements to process
Expand All @@ -36,7 +37,7 @@ namespace traccc::sycl::details {
/// @param queue The queue to use for the computation
/// @return A buffer of the created spacepoints
///
template <typename detector_t>
template <typename detector_t, typename kernel_t>
spacepoint_collection_types::buffer silicon_pixel_spacepoint_formation(
const typename detector_t::view_type& det_view,
const measurement_collection_types::const_view& measurements_view,
Expand Down Expand Up @@ -64,7 +65,7 @@ spacepoint_collection_types::buffer silicon_pixel_spacepoint_formation(
// Run the spacepoint formation on the device.
queue
.submit([&](cl::sycl::handler& h) {
h.parallel_for(
h.parallel_for<kernel_t>(
countRange, [det_view, measurements_view, n_measurements,
spacepoints_view = vecmem::get_data(result)](
cl::sycl::nd_item<1> item) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,20 @@
#include "traccc/sycl/seeding/silicon_pixel_spacepoint_formation_algorithm.hpp"

namespace traccc::sycl {
namespace kernels {

struct form_spacepoints_default_detector;

} // namespace kernels

silicon_pixel_spacepoint_formation_algorithm::output_type
silicon_pixel_spacepoint_formation_algorithm::operator()(
const default_detector::view& det,
const measurement_collection_types::const_view& meas) const {

return details::silicon_pixel_spacepoint_formation<
default_detector::device>(det, meas, m_mr.main, m_copy,
details::get_queue(m_queue));
default_detector::device, kernels::form_spacepoints_default_detector>(
det, meas, m_mr.main, m_copy, details::get_queue(m_queue));
}

} // namespace traccc::sycl
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,21 @@
#include "traccc/sycl/seeding/silicon_pixel_spacepoint_formation_algorithm.hpp"

namespace traccc::sycl {
namespace kernels {

struct form_spacepoints_telescope_detector;

} // namespace kernels

silicon_pixel_spacepoint_formation_algorithm::output_type
silicon_pixel_spacepoint_formation_algorithm::operator()(
const telescope_detector::view& det,
const measurement_collection_types::const_view& meas) const {

return details::silicon_pixel_spacepoint_formation<
telescope_detector::device>(det, meas, m_mr.main, m_copy,
details::get_queue(m_queue));
telescope_detector::device,
kernels::form_spacepoints_telescope_detector>(
det, meas, m_mr.main, m_copy, details::get_queue(m_queue));
}

} // namespace traccc::sycl

0 comments on commit ed91591

Please sign in to comment.