Skip to content

Commit

Permalink
rewrite some parts with concepts
Browse files Browse the repository at this point in the history
  • Loading branch information
cvarni committed Oct 2, 2024
1 parent 85bc5c6 commit dcdeac4
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 82 deletions.
39 changes: 33 additions & 6 deletions Core/include/Acts/Clusterization/Clusterization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,43 @@

namespace Acts::Ccl {

template <typename Cell>
concept HasRetrievableColumnInfo = requires(Cell cell) {
{ getCellColumn(cell) } -> std::same_as<int>;
};

template <typename Cell>
concept HasRetrievableRowInfo = requires(Cell cell) {
{ getCellRow(cell) } -> std::same_as<int>;
};

template <typename Cell>
concept HasRetrievableLabelInfo = requires(Cell cell) {
{ getCellLabel(cell) } -> std::same_as<int&>;
};

template <typename Cell, typename Cluster>
concept CanAcceptCell = requires(Cell cell, Cluster cluster) {
{ clusterAddCell(cluster, cell) } -> std::same_as<void>;
};

using Label = int;
constexpr Label NO_LABEL = 0;

// When looking for a cell connected to a reference cluster, the code
// always loops backward, starting from the reference cell. Since
// the cells are globally sorted column-wise, the connection function
// can therefore tell when the search should be stopped.
enum class ConnectResult {
enum class ConnectResult : std::uint8_t {
eNoConn, // No connections, keep looking
eNoConnStop, // No connections, stop looking
eConn // Found connection
};

// Default connection type for 2-D grids: 4- or 8-cell connectivity
template <typename Cell>
requires(Acts::Ccl::HasRetrievableColumnInfo<Cell> &&
Acts::Ccl::HasRetrievableRowInfo<Cell>)
struct Connect2D {
bool conn8{true};
Connect2D() = default;
Expand All @@ -36,7 +58,7 @@ struct Connect2D {
};

// Default connection type for 1-D grids: 2-cell connectivity
template <typename Cell>
template <Acts::Ccl::HasRetrievableColumnInfo Cell>
struct Connect1D {
virtual ConnectResult operator()(const Cell& ref, const Cell& iter) const;
};
Expand All @@ -48,15 +70,15 @@ struct DefaultConnect {
"Only grid dimensions of 1 or 2 are supported");
};

template <typename Cell>
struct DefaultConnect<Cell, 1> : public Connect1D<Cell> {};

template <typename Cell>
struct DefaultConnect<Cell, 2> : public Connect2D<Cell> {
explicit DefaultConnect(bool commonCorner) : Connect2D<Cell>(commonCorner) {}
DefaultConnect() : DefaultConnect(true) {}
DefaultConnect() = default;
};

template <typename Cell>
struct DefaultConnect<Cell, 1> : public Connect1D<Cell> {};

/// @brief labelClusters
///
/// In-place connected component labelling using the Hoshen-Kopelman algorithm.
Expand All @@ -70,6 +92,8 @@ struct DefaultConnect<Cell, 1> : public Connect1D<Cell> {};
template <typename CellCollection, std::size_t GridDim = 2,
typename Connect =
DefaultConnect<typename CellCollection::value_type, GridDim>>
requires(
Acts::Ccl::HasRetrievableLabelInfo<typename CellCollection::value_type>)
void labelClusters(CellCollection& cells, Connect connect = Connect());

/// @brief mergeClusters
Expand All @@ -82,6 +106,9 @@ void labelClusters(CellCollection& cells, Connect connect = Connect());
/// @return nothing
template <typename CellCollection, typename ClusterCollection,
std::size_t GridDim>
requires(GridDim == 1 || GridDim == 2) &&
Acts::Ccl::HasRetrievableLabelInfo<
typename CellCollection::value_type>
ClusterCollection mergeClusters(CellCollection& /*cells*/);

/// @brief createClusters
Expand Down
100 changes: 25 additions & 75 deletions Core/include/Acts/Clusterization/Clusterization.ipp
Original file line number Diff line number Diff line change
Expand Up @@ -14,86 +14,34 @@

namespace Acts::Ccl::internal {

// Machinery for validating generic Cell/Cluster types at compile-time

template <typename, std::size_t, typename T = void>
struct cellTypeHasRequiredFunctions : std::false_type {};

template <typename T>
struct cellTypeHasRequiredFunctions<
T, 2,
std::void_t<decltype(getCellRow(std::declval<T>())),
decltype(getCellColumn(std::declval<T>())),
decltype(getCellLabel(std::declval<T&>()))>> : std::true_type {
};

template <typename T>
struct cellTypeHasRequiredFunctions<
T, 1,
std::void_t<decltype(getCellColumn(std::declval<T>())),
decltype(getCellLabel(std::declval<T&>()))>> : std::true_type {
};

template <typename, typename, typename T = void>
struct clusterTypeHasRequiredFunctions : std::false_type {};

template <typename T, typename U>
struct clusterTypeHasRequiredFunctions<
T, U,
std::void_t<decltype(clusterAddCell(std::declval<T>(), std::declval<U>()))>>
: std::true_type {};

template <std::size_t GridDim>
constexpr void staticCheckGridDim() {
static_assert(
GridDim == 1 || GridDim == 2,
"mergeClusters is only defined for grid dimensions of 1 or 2. ");
}

template <typename T, std::size_t GridDim>
constexpr void staticCheckCellType() {
constexpr bool hasFns = cellTypeHasRequiredFunctions<T, GridDim>();
static_assert(hasFns,
"Cell type should have the following functions: "
"'int getCellRow(const Cell&)', "
"'int getCellColumn(const Cell&)', "
"'Label& getCellLabel(Cell&)'");
}

template <typename T, typename U>
constexpr void staticCheckClusterType() {
constexpr bool hasFns = clusterTypeHasRequiredFunctions<T, U>();
static_assert(hasFns,
"Cluster type should have the following function: "
"'void clusterAddCell(Cluster&, const Cell&)'");
}

template <typename Cell, std::size_t GridDim>
struct Compare {
static_assert(GridDim != 1 && GridDim != 2,
"Only grid dimensions of 1 or 2 are supported");
};

// Comparator function object for cells, column-wise ordering
// Specialization for 2-D grid
template <typename Cell>
struct Compare<Cell, 2> {
// Specialization for 1-D grids
template <Acts::Ccl::HasRetrievableColumnInfo Cell>
struct Compare<Cell, 1> {
bool operator()(const Cell& c0, const Cell& c1) const {
int row0 = getCellRow(c0);
int row1 = getCellRow(c1);
int col0 = getCellColumn(c0);
int col1 = getCellColumn(c1);
return (col0 == col1) ? row0 < row1 : col0 < col1;
return col0 < col1;
}
};

// Specialization for 1-D grids
// Specialization for 2-D grid
template <typename Cell>
struct Compare<Cell, 1> {
requires(Acts::Ccl::HasRetrievableColumnInfo<Cell> &&
Acts::Ccl::HasRetrievableRowInfo<Cell>)
struct Compare<Cell, 2> {
bool operator()(const Cell& c0, const Cell& c1) const {
int row0 = getCellRow(c0);
int row1 = getCellRow(c1);
int col0 = getCellColumn(c0);
int col1 = getCellColumn(c1);
return col0 < col1;
return (col0 == col1) ? row0 < row1 : col0 < col1;
}
};

Expand Down Expand Up @@ -184,6 +132,10 @@ Connections<GridDim> getConnections(typename std::vector<Cell>::iterator it,
}

template <typename CellCollection, typename ClusterCollection>
requires(
Acts::Ccl::HasRetrievableLabelInfo<typename CellCollection::value_type> &&
Acts::Ccl::CanAcceptCell<typename CellCollection::value_type,
typename ClusterCollection::value_type>)
ClusterCollection mergeClustersImpl(CellCollection& cells) {
using Cluster = typename ClusterCollection::value_type;

Expand Down Expand Up @@ -215,6 +167,8 @@ ClusterCollection mergeClustersImpl(CellCollection& cells) {
namespace Acts::Ccl {

template <typename Cell>
requires(Acts::Ccl::HasRetrievableColumnInfo<Cell> &&
Acts::Ccl::HasRetrievableRowInfo<Cell>)
ConnectResult Connect2D<Cell>::operator()(const Cell& ref,
const Cell& iter) const {
int deltaRow = std::abs(getCellRow(ref) - getCellRow(iter));
Expand All @@ -237,7 +191,7 @@ ConnectResult Connect2D<Cell>::operator()(const Cell& ref,
return ConnectResult::eNoConn;
}

template <typename Cell>
template <Acts::Ccl::HasRetrievableColumnInfo Cell>
ConnectResult Connect1D<Cell>::operator()(const Cell& ref,
const Cell& iter) const {
int deltaCol = std::abs(getCellColumn(ref) - getCellColumn(iter));
Expand Down Expand Up @@ -267,17 +221,19 @@ void recordEquivalences(const internal::Connections<GridDim> seen,
}

template <typename CellCollection, std::size_t GridDim, typename Connect>
requires(
Acts::Ccl::HasRetrievableLabelInfo<typename CellCollection::value_type>)
void labelClusters(CellCollection& cells, Connect connect) {
using Cell = typename CellCollection::value_type;
internal::staticCheckCellType<Cell, GridDim>();

internal::DisjointSets ds{};

// Sort cells by position to enable in-order scan
std::ranges::sort(cells, internal::Compare<Cell, GridDim>());

// First pass: Allocate labels and record equivalences
for (auto it = cells.begin(); it != cells.end(); ++it) {
for (auto it = std::ranges::begin(cells); it != std::ranges::end(cells);
++it) {
const internal::Connections<GridDim> seen =
internal::getConnections<Cell, Connect, GridDim>(it, cells, connect);
if (seen.nconn == 0) {
Expand All @@ -299,13 +255,11 @@ void labelClusters(CellCollection& cells, Connect connect) {

template <typename CellCollection, typename ClusterCollection,
std::size_t GridDim = 2>
requires(GridDim == 1 || GridDim == 2) &&
Acts::Ccl::HasRetrievableLabelInfo<
typename CellCollection::value_type>
ClusterCollection mergeClusters(CellCollection& cells) {
using Cell = typename CellCollection::value_type;
using Cluster = typename ClusterCollection::value_type;
internal::staticCheckGridDim<GridDim>();
internal::staticCheckCellType<Cell, GridDim>();
internal::staticCheckClusterType<Cluster&, const Cell&>();

if constexpr (GridDim > 1) {
// Sort the cells by their cluster label, only needed if more than
// one spatial dimension
Expand All @@ -318,10 +272,6 @@ ClusterCollection mergeClusters(CellCollection& cells) {
template <typename CellCollection, typename ClusterCollection,
std::size_t GridDim, typename Connect>
ClusterCollection createClusters(CellCollection& cells, Connect connect) {
using Cell = typename CellCollection::value_type;
using Cluster = typename ClusterCollection::value_type;
internal::staticCheckCellType<Cell, GridDim>();
internal::staticCheckClusterType<Cluster&, const Cell&>();
labelClusters<CellCollection, GridDim, Connect>(cells, connect);
return mergeClusters<CellCollection, ClusterCollection, GridDim>(cells);
}
Expand Down
2 changes: 1 addition & 1 deletion Core/include/Acts/Clusterization/TimedClusterization.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// This file is part of the ACTS project.
//
// Copyright (C) 2016-2024 CERN for the benefit of the ACTS project
// Copyright (C) 2016 CERN for the benefit of the ACTS project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
Expand Down
45 changes: 45 additions & 0 deletions Tests/UnitTests/Core/Clusterization/TimedClusterizationTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,51 @@ static void clusterAddCell(Cluster& cl, const Cell& cell) {
cl.ids.push_back(cell.id);
}

BOOST_AUTO_TEST_CASE(TimedGrid_1D_withtime) {
// 1x10 matrix
/*
X X X Y O X Y Y X X
*/
// 6 + 3 cells -> 3 + 2 clusters in total

std::vector<Cell> cells;
// X
cells.emplace_back(0ul, 0, -1, 0);
cells.emplace_back(1ul, 1, -1, 0);
cells.emplace_back(2ul, 2, -1, 0);
cells.emplace_back(3ul, 5, -1, 0);
cells.emplace_back(4ul, 8, -1, 0);
cells.emplace_back(5ul, 9, -1, 0);
// Y
cells.emplace_back(6ul, 3, 0, 1);
cells.emplace_back(7ul, 6, 1, 1);
cells.emplace_back(8ul, 7, 1, 1);

std::vector<std::vector<Identifier>> expectedResults;
expectedResults.push_back({0ul, 1ul, 2ul});
expectedResults.push_back({6ul});
expectedResults.push_back({3ul});
expectedResults.push_back({7ul, 8ul});
expectedResults.push_back({4ul, 5ul});

ClusterCollection clusters =
Acts::Ccl::createClusters<CellCollection, ClusterCollection, 1>(
cells, Acts::Ccl::TimedConnect<Cell, 1>(0.5));

BOOST_CHECK_EQUAL(5ul, clusters.size());

for (std::size_t i(0); i < clusters.size(); ++i) {
std::vector<Identifier>& timedIds = clusters[i].ids;
const std::vector<Identifier>& expected = expectedResults[i];
std::sort(timedIds.begin(), timedIds.end());
BOOST_CHECK_EQUAL(timedIds.size(), expected.size());

for (std::size_t j(0); j < timedIds.size(); ++j) {
BOOST_CHECK_EQUAL(timedIds[j], expected[j]);
}
}
}

BOOST_AUTO_TEST_CASE(TimedGrid_2D_notime) {
// 4x4 matrix
/*
Expand Down

0 comments on commit dcdeac4

Please sign in to comment.