Skip to content

Commit

Permalink
Call device functions consistently from SYCL.
Browse files Browse the repository at this point in the history
  • Loading branch information
krasznaa committed Jan 8, 2025
1 parent 36e94c5 commit ced855e
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 31 deletions.
7 changes: 4 additions & 3 deletions device/sycl/src/fitting/fit_tracks.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/** TRACCC library, part of the ACTS project (R&D line)
*
* (c) 2022-2024 CERN for the benefit of the ACTS project
* (c) 2022-2025 CERN for the benefit of the ACTS project
*
* Mozilla Public License Version 2.0
*/
Expand All @@ -9,6 +9,7 @@

// Local include(s).
#include "../utils/calculate1DimNdRange.hpp"
#include "../utils/global_index.hpp"

// Project include(s).
#include "traccc/edm/device/sort_key.hpp"
Expand Down Expand Up @@ -96,7 +97,7 @@ track_state_container_types::buffer fit_tracks(
[track_candidates_view, keys_view = vecmem::get_data(keys_buffer),
param_ids_view =
vecmem::get_data(param_ids_buffer)](::sycl::nd_item<1> item) {
device::fill_sort_keys(item.get_global_linear_id(),
device::fill_sort_keys(details::global_index(item),
track_candidates_view, keys_view,
param_ids_view);
});
Expand All @@ -120,7 +121,7 @@ track_state_container_types::buffer fit_tracks(
range, [det_view, field_view, config, track_candidates_view,
param_ids_view = vecmem::get_data(param_ids_buffer),
track_states_view](::sycl::nd_item<1> item) {
device::fit<fitter_t>(item.get_global_linear_id(), det_view,
device::fit<fitter_t>(details::global_index(item), det_view,
field_view, config,
track_candidates_view, param_ids_view,
track_states_view);
Expand Down
29 changes: 14 additions & 15 deletions device/sycl/src/seeding/seed_finding.sycl
Original file line number Diff line number Diff line change
@@ -1,19 +1,15 @@
/** TRACCC library, part of the ACTS project (R&D line)
*
* (c) 2021-2024 CERN for the benefit of the ACTS project
* (c) 2021-2025 CERN for the benefit of the ACTS project
*
* Mozilla Public License Version 2.0
*/

// System include(s).
#include <algorithm>

// SYCL library include(s).
#include "traccc/sycl/seeding/seed_finding.hpp"

// SYCL library include(s).
// Local include(s).
#include "../utils/calculate1DimNdRange.hpp"
#include "../utils/get_queue.hpp"
#include "../utils/global_index.hpp"
#include "traccc/sycl/seeding/seed_finding.hpp"
#include "traccc/sycl/utils/make_prefix_sum_buff.hpp"

// Project include(s).
Expand All @@ -34,6 +30,9 @@
// VecMem include(s).
#include <vecmem/utils/sycl/local_accessor.hpp>

// System include(s).
#include <algorithm>

namespace traccc::sycl {
namespace kernels {

Expand Down Expand Up @@ -123,7 +122,7 @@ seed_finding::output_type seed_finding::operator()(
[config = m_seedfinder_config, g2_view, sp_grid_prefix_sum_view,
doublet_counter_view,
aux_globalCounter](::sycl::nd_item<1> item) {
device::count_doublets(item.get_global_linear_id(), config,
device::count_doublets(details::global_index(item), config,
g2_view, sp_grid_prefix_sum_view,
doublet_counter_view,
(*aux_globalCounter).m_nMidBot,
Expand Down Expand Up @@ -171,7 +170,7 @@ seed_finding::output_type seed_finding::operator()(
doubletFindRange,
[config = m_seedfinder_config, g2_view, doublet_counter_view,
mb_view, mt_view](::sycl::nd_item<1> item) {
device::find_doublets(item.get_global_linear_id(), config,
device::find_doublets(details::global_index(item), config,
g2_view, doublet_counter_view,
mb_view, mt_view);
});
Expand Down Expand Up @@ -210,7 +209,7 @@ seed_finding::output_type seed_finding::operator()(
mb_view, mt_view, triplet_counter_spM_view,
triplet_counter_midBot_view](::sycl::nd_item<1> item) {
device::count_triplets(
item.get_global_linear_id(), config, g2_view,
details::global_index(item), config, g2_view,
doublet_counter_view, mb_view, mt_view,
triplet_counter_spM_view, triplet_counter_midBot_view);
});
Expand All @@ -232,7 +231,7 @@ seed_finding::output_type seed_finding::operator()(
[doublet_counter_view, triplet_counter_spM_view,
aux_globalCounter](::sycl::nd_item<1> item) {
device::reduce_triplet_counts(
item.get_global_linear_id(), doublet_counter_view,
details::global_index(item), doublet_counter_view,
triplet_counter_spM_view,
(*aux_globalCounter).m_nTriplets);
});
Expand Down Expand Up @@ -270,7 +269,7 @@ seed_finding::output_type seed_finding::operator()(
triplet_counter_midBot_view,
triplet_view](::sycl::nd_item<1> item) {
device::find_triplets(
item.get_global_linear_id(), config, filter_config,
details::global_index(item), config, filter_config,
g2_view, doublet_counter_view, mt_view,
triplet_counter_spM_view, triplet_counter_midBot_view,
triplet_view);
Expand Down Expand Up @@ -311,7 +310,7 @@ seed_finding::output_type seed_finding::operator()(
filter_config.compatSeedLimit];

device::update_triplet_weights(
item.get_global_linear_id(), filter_config, g2_view,
details::global_index(item), filter_config, g2_view,
triplet_counter_spM_view, triplet_counter_midBot_view,
dataPos, triplet_view);
});
Expand Down Expand Up @@ -359,7 +358,7 @@ seed_finding::output_type seed_finding::operator()(
&local_mem[item.get_local_id() *
filter_config.max_triplets_per_spM];

device::select_seeds(item.get_global_linear_id(),
device::select_seeds(details::global_index(item),
filter_config, spacepoints_view,
g2_view, triplet_counter_spM_view,
triplet_counter_midBot_view,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/** TRACCC library, part of the ACTS project (R&D line)
*
* (c) 2023-2024 CERN for the benefit of the ACTS project
* (c) 2023-2025 CERN for the benefit of the ACTS project
*
* Mozilla Public License Version 2.0
*/
Expand All @@ -10,6 +10,7 @@
// Local include(s).
#include "../utils/calculate1DimNdRange.hpp"
#include "../utils/get_queue.hpp"
#include "../utils/global_index.hpp"

// Project include(s).
#include "traccc/edm/measurement.hpp"
Expand Down Expand Up @@ -69,7 +70,7 @@ spacepoint_collection_types::buffer silicon_pixel_spacepoint_formation(
spacepoints_view = vecmem::get_data(result)](
::sycl::nd_item<1> item) {
device::form_spacepoints<detector_t>(
item.get_global_linear_id(), det_view,
details::global_index(item), det_view,
measurements_view, n_measurements, spacepoints_view);
});
})
Expand Down
14 changes: 6 additions & 8 deletions device/sycl/src/seeding/spacepoint_binning.sycl
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
/** TRACCC library, part of the ACTS project (R&D line)
*
* (c) 2021-2024 CERN for the benefit of the ACTS project
* (c) 2021-2025 CERN for the benefit of the ACTS project
*
* Mozilla Public License Version 2.0
*/

// Local include(s).
#include "../utils/calculate1DimNdRange.hpp"
#include "traccc/sycl/seeding/spacepoint_binning.hpp"

// Local include(s).
#include "../utils/get_queue.hpp"
#include "../utils/global_index.hpp"
#include "traccc/sycl/seeding/spacepoint_binning.hpp"

// Project include(s).
#include "traccc/seeding/device/count_grid_capacities.hpp"
Expand Down Expand Up @@ -72,7 +71,7 @@ sp_grid_buffer spacepoint_binning::operator()(
z_axis = m_axes.second, spacepoints = spacepoints_view,
grid_capacities =
grid_capacities_view](::sycl::nd_item<1> item) {
device::count_grid_capacities(item.get_global_linear_id(),
device::count_grid_capacities(details::global_index(item),
config, phi_axis, z_axis,
spacepoints, grid_capacities);
});
Expand All @@ -99,9 +98,8 @@ sp_grid_buffer spacepoint_binning::operator()(
h.parallel_for<kernels::populate_grid>(
range, [config = m_config, spacepoints = spacepoints_view,
grid = grid_view](::sycl::nd_item<1> item) {
device::populate_grid(
static_cast<unsigned int>(item.get_global_linear_id()),
config, spacepoints, grid);
device::populate_grid(details::global_index(item), config,
spacepoints, grid);
});
})
.wait_and_throw();
Expand Down
7 changes: 4 additions & 3 deletions device/sycl/src/seeding/track_params_estimation.sycl
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
/** TRACCC library, part of the ACTS project (R&D line)
*
* (c) 2021-2024 CERN for the benefit of the ACTS project
* (c) 2021-2025 CERN for the benefit of the ACTS project
*
* Mozilla Public License Version 2.0
*/

// SYCL library include(s).
// Local include(s).
#include "../utils/calculate1DimNdRange.hpp"
#include "../utils/get_queue.hpp"
#include "../utils/global_index.hpp"
#include "traccc/sycl/seeding/track_params_estimation.hpp"

// Project include(s).
Expand Down Expand Up @@ -63,7 +64,7 @@ track_params_estimation::output_type track_params_estimation::operator()(
trackParamsNdRange,
[spacepoints_view, seeds_view, bfield, stddev,
params_view](::sycl::nd_item<1> item) {
device::estimate_track_params(item.get_global_linear_id(),
device::estimate_track_params(details::global_index(item),
spacepoints_view, seeds_view,
bfield, stddev, params_view);
});
Expand Down
36 changes: 36 additions & 0 deletions device/sycl/src/utils/global_index.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/** TRACCC library, part of the ACTS project (R&D line)
*
* (c) 2025 CERN for the benefit of the ACTS project
*
* Mozilla Public License Version 2.0
*/

#pragma once

// Project include(s).
#include "traccc/device/global_index.hpp"

// SYCL include(s).
#include <sycl/sycl.hpp>

namespace traccc::sycl::details {

/// Function creating a global index in a 1D SYCL kernel
inline device::global_index_t global_index(const ::sycl::nd_item<1>& item) {

return static_cast<device::global_index_t>(item.get_global_linear_id());
}

/// Function creating a global index in a 2D SYCL kernel
inline device::global_index_t global_index(const ::sycl::nd_item<2>& item) {

return static_cast<device::global_index_t>(item.get_global_linear_id());
}

/// Function creating a global index in a 3D SYCL kernel
inline device::global_index_t global_index(const ::sycl::nd_item<3>& item) {

return static_cast<device::global_index_t>(item.get_global_linear_id());
}

} // namespace traccc::sycl::details

0 comments on commit ced855e

Please sign in to comment.