diff --git a/Core/include/Acts/Clusterization/Clusterization.hpp b/Core/include/Acts/Clusterization/Clusterization.hpp index fc40bd4b5a6..839cc3d4daa 100644 --- a/Core/include/Acts/Clusterization/Clusterization.hpp +++ b/Core/include/Acts/Clusterization/Clusterization.hpp @@ -13,6 +13,26 @@ namespace Acts::Ccl { +template +concept HasRetrievableColumnInfo = requires(Cell cell) { + { getCellColumn(cell) } -> std::same_as; +}; + +template +concept HasRetrievableRowInfo = requires(Cell cell) { + { getCellRow(cell) } -> std::same_as; +}; + +template +concept HasRetrievableLabelInfo = requires(Cell cell) { + { getCellLabel(cell) } -> std::same_as; +}; + +template +concept CanAcceptCell = requires(Cell cell, Cluster cluster) { + { clusterAddCell(cluster, cell) } -> std::same_as; +}; + using Label = int; constexpr Label NO_LABEL = 0; @@ -20,7 +40,7 @@ constexpr Label NO_LABEL = 0; // always loops backward, starting from the reference cell. Since // the cells are globally sorted column-wise, the connection function // can therefore tell when the search should be stopped. -enum class ConnectResult { +enum class ConnectResult : std::uint8_t { eNoConn, // No connections, keep looking eNoConnStop, // No connections, stop looking eConn // Found connection @@ -28,6 +48,8 @@ enum class ConnectResult { // Default connection type for 2-D grids: 4- or 8-cell connectivity template + requires(Acts::Ccl::HasRetrievableColumnInfo && + Acts::Ccl::HasRetrievableRowInfo) struct Connect2D { bool conn8{true}; Connect2D() = default; @@ -36,7 +58,7 @@ struct Connect2D { }; // Default connection type for 1-D grids: 2-cell connectivity -template +template struct Connect1D { virtual ConnectResult operator()(const Cell& ref, const Cell& iter) const; }; @@ -48,15 +70,15 @@ struct DefaultConnect { "Only grid dimensions of 1 or 2 are supported"); }; +template +struct DefaultConnect : public Connect1D {}; + template struct DefaultConnect : public Connect2D { explicit DefaultConnect(bool commonCorner) : Connect2D(commonCorner) {} - DefaultConnect() : DefaultConnect(true) {} + DefaultConnect() = default; }; -template -struct DefaultConnect : public Connect1D {}; - /// @brief labelClusters /// /// In-place connected component labelling using the Hoshen-Kopelman algorithm. @@ -70,6 +92,8 @@ struct DefaultConnect : public Connect1D {}; template > + requires( + Acts::Ccl::HasRetrievableLabelInfo) void labelClusters(CellCollection& cells, Connect connect = Connect()); /// @brief mergeClusters @@ -82,6 +106,9 @@ void labelClusters(CellCollection& cells, Connect connect = Connect()); /// @return nothing template + requires(GridDim == 1 || GridDim == 2) && + Acts::Ccl::HasRetrievableLabelInfo< + typename CellCollection::value_type> ClusterCollection mergeClusters(CellCollection& /*cells*/); /// @brief createClusters diff --git a/Core/include/Acts/Clusterization/Clusterization.ipp b/Core/include/Acts/Clusterization/Clusterization.ipp index 0aa79b53981..90a69930ce1 100644 --- a/Core/include/Acts/Clusterization/Clusterization.ipp +++ b/Core/include/Acts/Clusterization/Clusterization.ipp @@ -14,60 +14,6 @@ namespace Acts::Ccl::internal { -// Machinery for validating generic Cell/Cluster types at compile-time - -template -struct cellTypeHasRequiredFunctions : std::false_type {}; - -template -struct cellTypeHasRequiredFunctions< - T, 2, - std::void_t())), - decltype(getCellColumn(std::declval())), - decltype(getCellLabel(std::declval()))>> : std::true_type { -}; - -template -struct cellTypeHasRequiredFunctions< - T, 1, - std::void_t())), - decltype(getCellLabel(std::declval()))>> : std::true_type { -}; - -template -struct clusterTypeHasRequiredFunctions : std::false_type {}; - -template -struct clusterTypeHasRequiredFunctions< - T, U, - std::void_t(), std::declval()))>> - : std::true_type {}; - -template -constexpr void staticCheckGridDim() { - static_assert( - GridDim == 1 || GridDim == 2, - "mergeClusters is only defined for grid dimensions of 1 or 2. "); -} - -template -constexpr void staticCheckCellType() { - constexpr bool hasFns = cellTypeHasRequiredFunctions(); - static_assert(hasFns, - "Cell type should have the following functions: " - "'int getCellRow(const Cell&)', " - "'int getCellColumn(const Cell&)', " - "'Label& getCellLabel(Cell&)'"); -} - -template -constexpr void staticCheckClusterType() { - constexpr bool hasFns = clusterTypeHasRequiredFunctions(); - static_assert(hasFns, - "Cluster type should have the following function: " - "'void clusterAddCell(Cluster&, const Cell&)'"); -} - template struct Compare { static_assert(GridDim != 1 && GridDim != 2, @@ -75,25 +21,27 @@ struct Compare { }; // Comparator function object for cells, column-wise ordering -// Specialization for 2-D grid -template -struct Compare { +// Specialization for 1-D grids +template +struct Compare { bool operator()(const Cell& c0, const Cell& c1) const { - int row0 = getCellRow(c0); - int row1 = getCellRow(c1); int col0 = getCellColumn(c0); int col1 = getCellColumn(c1); - return (col0 == col1) ? row0 < row1 : col0 < col1; + return col0 < col1; } }; -// Specialization for 1-D grids +// Specialization for 2-D grid template -struct Compare { + requires(Acts::Ccl::HasRetrievableColumnInfo && + Acts::Ccl::HasRetrievableRowInfo) +struct Compare { bool operator()(const Cell& c0, const Cell& c1) const { + int row0 = getCellRow(c0); + int row1 = getCellRow(c1); int col0 = getCellColumn(c0); int col1 = getCellColumn(c1); - return col0 < col1; + return (col0 == col1) ? row0 < row1 : col0 < col1; } }; @@ -184,6 +132,10 @@ Connections getConnections(typename std::vector::iterator it, } template + requires( + Acts::Ccl::HasRetrievableLabelInfo && + Acts::Ccl::CanAcceptCell) ClusterCollection mergeClustersImpl(CellCollection& cells) { using Cluster = typename ClusterCollection::value_type; @@ -215,6 +167,8 @@ ClusterCollection mergeClustersImpl(CellCollection& cells) { namespace Acts::Ccl { template + requires(Acts::Ccl::HasRetrievableColumnInfo && + Acts::Ccl::HasRetrievableRowInfo) ConnectResult Connect2D::operator()(const Cell& ref, const Cell& iter) const { int deltaRow = std::abs(getCellRow(ref) - getCellRow(iter)); @@ -237,7 +191,7 @@ ConnectResult Connect2D::operator()(const Cell& ref, return ConnectResult::eNoConn; } -template +template ConnectResult Connect1D::operator()(const Cell& ref, const Cell& iter) const { int deltaCol = std::abs(getCellColumn(ref) - getCellColumn(iter)); @@ -267,9 +221,10 @@ void recordEquivalences(const internal::Connections seen, } template + requires( + Acts::Ccl::HasRetrievableLabelInfo) void labelClusters(CellCollection& cells, Connect connect) { using Cell = typename CellCollection::value_type; - internal::staticCheckCellType(); internal::DisjointSets ds{}; @@ -277,7 +232,8 @@ void labelClusters(CellCollection& cells, Connect connect) { std::ranges::sort(cells, internal::Compare()); // First pass: Allocate labels and record equivalences - for (auto it = cells.begin(); it != cells.end(); ++it) { + for (auto it = std::ranges::begin(cells); it != std::ranges::end(cells); + ++it) { const internal::Connections seen = internal::getConnections(it, cells, connect); if (seen.nconn == 0) { @@ -299,13 +255,11 @@ void labelClusters(CellCollection& cells, Connect connect) { template + requires(GridDim == 1 || GridDim == 2) && + Acts::Ccl::HasRetrievableLabelInfo< + typename CellCollection::value_type> ClusterCollection mergeClusters(CellCollection& cells) { using Cell = typename CellCollection::value_type; - using Cluster = typename ClusterCollection::value_type; - internal::staticCheckGridDim(); - internal::staticCheckCellType(); - internal::staticCheckClusterType(); - if constexpr (GridDim > 1) { // Sort the cells by their cluster label, only needed if more than // one spatial dimension @@ -318,10 +272,6 @@ ClusterCollection mergeClusters(CellCollection& cells) { template ClusterCollection createClusters(CellCollection& cells, Connect connect) { - using Cell = typename CellCollection::value_type; - using Cluster = typename ClusterCollection::value_type; - internal::staticCheckCellType(); - internal::staticCheckClusterType(); labelClusters(cells, connect); return mergeClusters(cells); } diff --git a/Core/include/Acts/Clusterization/TimedClusterization.hpp b/Core/include/Acts/Clusterization/TimedClusterization.hpp index 0e55afa2843..b3e77f75604 100644 --- a/Core/include/Acts/Clusterization/TimedClusterization.hpp +++ b/Core/include/Acts/Clusterization/TimedClusterization.hpp @@ -1,6 +1,6 @@ // This file is part of the ACTS project. // -// Copyright (C) 2016-2024 CERN for the benefit of the ACTS project +// Copyright (C) 2016 CERN for the benefit of the ACTS project // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this diff --git a/Tests/UnitTests/Core/Clusterization/TimedClusterizationTests.cpp b/Tests/UnitTests/Core/Clusterization/TimedClusterizationTests.cpp index bfefce7f347..414569b41cb 100644 --- a/Tests/UnitTests/Core/Clusterization/TimedClusterizationTests.cpp +++ b/Tests/UnitTests/Core/Clusterization/TimedClusterizationTests.cpp @@ -53,6 +53,51 @@ static void clusterAddCell(Cluster& cl, const Cell& cell) { cl.ids.push_back(cell.id); } +BOOST_AUTO_TEST_CASE(TimedGrid_1D_withtime) { + // 1x10 matrix + /* + X X X Y O X Y Y X X + */ + // 6 + 3 cells -> 3 + 2 clusters in total + + std::vector cells; + // X + cells.emplace_back(0ul, 0, -1, 0); + cells.emplace_back(1ul, 1, -1, 0); + cells.emplace_back(2ul, 2, -1, 0); + cells.emplace_back(3ul, 5, -1, 0); + cells.emplace_back(4ul, 8, -1, 0); + cells.emplace_back(5ul, 9, -1, 0); + // Y + cells.emplace_back(6ul, 3, 0, 1); + cells.emplace_back(7ul, 6, 1, 1); + cells.emplace_back(8ul, 7, 1, 1); + + std::vector> expectedResults; + expectedResults.push_back({0ul, 1ul, 2ul}); + expectedResults.push_back({6ul}); + expectedResults.push_back({3ul}); + expectedResults.push_back({7ul, 8ul}); + expectedResults.push_back({4ul, 5ul}); + + ClusterCollection clusters = + Acts::Ccl::createClusters( + cells, Acts::Ccl::TimedConnect(0.5)); + + BOOST_CHECK_EQUAL(5ul, clusters.size()); + + for (std::size_t i(0); i < clusters.size(); ++i) { + std::vector& timedIds = clusters[i].ids; + const std::vector& expected = expectedResults[i]; + std::sort(timedIds.begin(), timedIds.end()); + BOOST_CHECK_EQUAL(timedIds.size(), expected.size()); + + for (std::size_t j(0); j < timedIds.size(); ++j) { + BOOST_CHECK_EQUAL(timedIds[j], expected[j]); + } + } +} + BOOST_AUTO_TEST_CASE(TimedGrid_2D_notime) { // 4x4 matrix /*