Skip to content

Commit

Permalink
refactor!: template the seed finder on the grid type (#2957)
Browse files Browse the repository at this point in the history
Make the seed finder templated on the grid type so that we can use any kind of grid on it
  • Loading branch information
CarloVarni authored Mar 1, 2024
1 parent abdddc6 commit f0294bd
Show file tree
Hide file tree
Showing 8 changed files with 38 additions and 28 deletions.
13 changes: 6 additions & 7 deletions Core/include/Acts/Seeding/SeedFinder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
#include "Acts/Seeding/SeedFilter.hpp"
#include "Acts/Seeding/SeedFinderConfig.hpp"
#include "Acts/Seeding/SeedFinderUtils.hpp"
#include "Acts/Seeding/SpacePointGrid.hpp"
#include "Acts/Utilities/detail/grid_helper.hpp"

#include <array>
#include <limits>
Expand All @@ -37,12 +35,12 @@ enum class SpacePointCandidateType : short { eBottom, eTop };

enum class DetectorMeasurementInfo : short { eDefault, eDetailed };

template <typename external_spacepoint_t, typename platform_t = void*>
template <typename external_spacepoint_t, typename grid_t,
typename platform_t = void*>
class SeedFinder {
///////////////////////////////////////////////////////////////////
// Public methods:
///////////////////////////////////////////////////////////////////
using grid_t = Acts::CylindricalSpacePointGrid<external_spacepoint_t>;

public:
struct SeedingState {
Expand Down Expand Up @@ -83,9 +81,10 @@ class SeedFinder {
/** @name Disallow default instantiation, copy, assignment */
//@{
SeedFinder() = default;
SeedFinder(const SeedFinder<external_spacepoint_t, platform_t>&) = delete;
SeedFinder<external_spacepoint_t, platform_t>& operator=(
const SeedFinder<external_spacepoint_t, platform_t>&) = default;
SeedFinder(const SeedFinder<external_spacepoint_t, grid_t, platform_t>&) =
delete;
SeedFinder<external_spacepoint_t, grid_t, platform_t>& operator=(
const SeedFinder<external_spacepoint_t, grid_t, platform_t>&) = default;
//@}

/// Create all seeds from the space points in the three iterators.
Expand Down
28 changes: 14 additions & 14 deletions Core/include/Acts/Seeding/SeedFinder.ipp
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@

namespace Acts {

template <typename external_spacepoint_t, typename platform_t>
SeedFinder<external_spacepoint_t, platform_t>::SeedFinder(
template <typename external_spacepoint_t, typename grid_t, typename platform_t>
SeedFinder<external_spacepoint_t, grid_t, platform_t>::SeedFinder(
const Acts::SeedFinderConfig<external_spacepoint_t>& config)
: m_config(config) {
if (!config.isInInternalUnits) {
Expand All @@ -35,9 +35,9 @@ SeedFinder<external_spacepoint_t, platform_t>::SeedFinder(
}
}

template <typename external_spacepoint_t, typename platform_t>
template <typename external_spacepoint_t, typename grid_t, typename platform_t>
template <template <typename...> typename container_t, typename sp_range_t>
void SeedFinder<external_spacepoint_t, platform_t>::createSeedsForGroup(
void SeedFinder<external_spacepoint_t, grid_t, platform_t>::createSeedsForGroup(
const Acts::SeedFinderOptions& options, SeedingState& state,
const grid_t& grid,
std::back_insert_iterator<container_t<Seed<external_spacepoint_t>>> outIt,
Expand Down Expand Up @@ -192,16 +192,15 @@ void SeedFinder<external_spacepoint_t, platform_t>::createSeedsForGroup(
} // loop on mediums
}

template <typename external_spacepoint_t, typename platform_t>
template <typename external_spacepoint_t, typename grid_t, typename platform_t>
template <Acts::SpacePointCandidateType candidateType, typename out_range_t>
inline void
SeedFinder<external_spacepoint_t, platform_t>::getCompatibleDoublets(
SeedFinder<external_spacepoint_t, grid_t, platform_t>::getCompatibleDoublets(
Acts::SpacePointData& spacePointData,
const Acts::SeedFinderOptions& options, const grid_t& grid,
boost::container::small_vector<
Acts::Neighbour<
typename SeedFinder<external_spacepoint_t, platform_t>::grid_t>,
Acts::detail::ipow(3, grid_t::DIM)>& otherSPsNeighbours,
boost::container::small_vector<Acts::Neighbour<grid_t>,
Acts::detail::ipow(3, grid_t::DIM)>&
otherSPsNeighbours,
const InternalSpacePoint<external_spacepoint_t>& mediumSP,
std::vector<LinCircle>& linCircleVec, out_range_t& outVec,
const float deltaRMinSP, const float deltaRMaxSP, const float uIP,
Expand Down Expand Up @@ -456,9 +455,10 @@ SeedFinder<external_spacepoint_t, platform_t>::getCompatibleDoublets(
}
}

template <typename external_spacepoint_t, typename platform_t>
template <typename external_spacepoint_t, typename grid_t, typename platform_t>
template <Acts::DetectorMeasurementInfo detailedMeasurement>
inline void SeedFinder<external_spacepoint_t, platform_t>::filterCandidates(
inline void
SeedFinder<external_spacepoint_t, grid_t, platform_t>::filterCandidates(
Acts::SpacePointData& spacePointData,
const InternalSpacePoint<external_spacepoint_t>& spM,
const Acts::SeedFinderOptions& options, SeedFilterState& seedFilterState,
Expand Down Expand Up @@ -817,10 +817,10 @@ inline void SeedFinder<external_spacepoint_t, platform_t>::filterCandidates(
} // loop on bottoms
}

template <typename external_spacepoint_t, typename platform_t>
template <typename external_spacepoint_t, typename grid_t, typename platform_t>
template <typename sp_range_t>
std::vector<Seed<external_spacepoint_t>>
SeedFinder<external_spacepoint_t, platform_t>::createSeedsForGroup(
SeedFinder<external_spacepoint_t, grid_t, platform_t>::createSeedsForGroup(
const Acts::SeedFinderOptions& options, const grid_t& grid,
const sp_range_t& bottomSPs, const std::size_t middleSPs,
const sp_range_t& topSPs) const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ class SeedingAlgorithm final : public IAlgorithm {
const Config& config() const { return m_cfg; }

private:
Acts::SeedFinder<SimSpacePoint> m_seedFinder;
Acts::SeedFinder<SimSpacePoint,
Acts::CylindricalSpacePointGrid<SimSpacePoint>>
m_seedFinder;
std::unique_ptr<const Acts::GridBinFinder<2ul>> m_bottomBinFinder;
std::unique_ptr<const Acts::GridBinFinder<2ul>> m_topBinFinder;

Expand Down
5 changes: 4 additions & 1 deletion Examples/Algorithms/TrackFinding/src/SeedingAlgorithm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,10 @@ ActsExamples::SeedingAlgorithm::SeedingAlgorithm(

m_cfg.seedFinderConfig.seedFilter =
std::make_unique<Acts::SeedFilter<SimSpacePoint>>(m_cfg.seedFilterConfig);
m_seedFinder = Acts::SeedFinder<SimSpacePoint>(m_cfg.seedFinderConfig);
m_seedFinder =
Acts::SeedFinder<SimSpacePoint,
Acts::CylindricalSpacePointGrid<SimSpacePoint>>(
m_cfg.seedFinderConfig);
}

ActsExamples::ProcessCode ActsExamples::SeedingAlgorithm::execute(
Expand Down
6 changes: 4 additions & 2 deletions Tests/UnitTests/Core/Seeding/SeedFinderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,10 @@ int main(int argc, char** argv) {
Acts::ATLASCuts<SpacePoint> atlasCuts = Acts::ATLASCuts<SpacePoint>();
config.seedFilter = std::make_unique<Acts::SeedFilter<SpacePoint>>(
Acts::SeedFilter<SpacePoint>(sfconf, &atlasCuts));
Acts::SeedFinder<SpacePoint> a; // test creation of unconfigured finder
a = Acts::SeedFinder<SpacePoint>(config);
Acts::SeedFinder<SpacePoint, Acts::CylindricalSpacePointGrid<SpacePoint>>
a; // test creation of unconfigured finder
a = Acts::SeedFinder<SpacePoint, Acts::CylindricalSpacePointGrid<SpacePoint>>(
config);

// covariance tool, sets covariances per spacepoint as required
auto ct = [=](const SpacePoint& sp, float, float, float) {
Expand Down
3 changes: 2 additions & 1 deletion Tests/UnitTests/Plugins/Cuda/Seeding/SeedFinderCudaTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,8 @@ int main(int argc, char** argv) {
Acts::ATLASCuts<SpacePoint> atlasCuts = Acts::ATLASCuts<SpacePoint>();
config.seedFilter = std::make_unique<Acts::SeedFilter<SpacePoint>>(
Acts::SeedFilter<SpacePoint>(sfconf, &atlasCuts));
Acts::SeedFinder<SpacePoint> seedFinder_cpu(config);
Acts::SeedFinder<SpacePoint, Acts::CylindricalSpacePointGrid<SpacePoint>>
seedFinder_cpu(config);
Acts::SeedFinder<SpacePoint, Acts::Cuda> seedFinder_cuda(config, options);

// covariance tool, sets covariances per spacepoint as required
Expand Down
4 changes: 3 additions & 1 deletion Tests/UnitTests/Plugins/Cuda/Seeding2/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,9 @@ int main(int argc, char* argv[]) {
auto deviceCuts = testDeviceCuts();

// Set up the seedFinder objects.
Acts::SeedFinder<TestSpacePoint> seedFinder_host(sfConfig);
Acts::SeedFinder<TestSpacePoint,
Acts::CylindricalSpacePointGrid<TestSpacePoint>>
seedFinder_host(sfConfig);
Acts::Cuda::SeedFinder<TestSpacePoint> seedFinder_device(
sfConfig, sfOptions, filterConfig, deviceCuts, cmdl.cudaDevice);

Expand Down
3 changes: 2 additions & 1 deletion Tests/UnitTests/Plugins/Sycl/Seeding/SeedFinderSyclTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,8 @@ auto main(int argc, char** argv) -> int {
vecmem::sycl::device_memory_resource device_resource(queue.getQueue());
Acts::Sycl::SeedFinder<SpacePoint> syclSeedFinder(
config, options, deviceAtlasCuts, queue, resource, &device_resource);
Acts::SeedFinder<SpacePoint> normalSeedFinder(config);
Acts::SeedFinder<SpacePoint, Acts::CylindricalSpacePointGrid<SpacePoint>>
normalSeedFinder(config);
auto globalTool = [=](const SpacePoint& sp, float /*unused*/,
float /*unused*/, float /*unused*/)
-> std::tuple<Acts::Vector3, Acts::Vector2, std::optional<float>> {
Expand Down

0 comments on commit f0294bd

Please sign in to comment.