Skip to content

Commit

Permalink
Add support for Timed Clusterization
Browse files Browse the repository at this point in the history
  • Loading branch information
cvarni committed Oct 2, 2024
1 parent 661ab12 commit 0b3beeb
Show file tree
Hide file tree
Showing 6 changed files with 402 additions and 84 deletions.
50 changes: 41 additions & 9 deletions Core/include/Acts/Clusterization/Clusterization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,26 @@

namespace Acts::Ccl {

template <typename Cell>
concept HasRetrievableColumnInfo = requires(Cell cell) {
{ getCellColumn(cell) } -> std::same_as<int>;
};

template <typename Cell>
concept HasRetrievableRowInfo = requires(Cell cell) {
{ getCellRow(cell) } -> std::same_as<int>;
};

template <typename Cell>
concept HasRetrievableLabelInfo = requires(Cell cell) {
{ getCellLabel(cell) } -> std::same_as<int&>;
};

template <typename Cell, typename Cluster>
concept CanAcceptCell = requires(Cell cell, Cluster cluster) {
{ clusterAddCell(cluster, cell) } -> std::same_as<void>;
};

using Label = int;
constexpr Label NO_LABEL = 0;

Expand All @@ -28,17 +48,21 @@ enum class ConnectResult {

// Default connection type for 2-D grids: 4- or 8-cell connectivity
template <typename Cell>
requires(Acts::Ccl::HasRetrievableColumnInfo<Cell> &&
Acts::Ccl::HasRetrievableRowInfo<Cell>)
struct Connect2D {
bool conn8;
Connect2D() : conn8{true} {}
bool conn8{true};
Connect2D() = default;
explicit Connect2D(bool commonCorner) : conn8{commonCorner} {}
ConnectResult operator()(const Cell& ref, const Cell& iter) const;
virtual ConnectResult operator()(const Cell& ref, const Cell& iter) const;
virtual ~Connect2D() = default;
};

// Default connection type for 1-D grids: 2-cell connectivity
template <typename Cell>
template <Acts::Ccl::HasRetrievableColumnInfo Cell>
struct Connect1D {
ConnectResult operator()(const Cell& ref, const Cell& iter) const;
virtual ConnectResult operator()(const Cell& ref, const Cell& iter) const;
virtual ~Connect1D() = default;
};

// Default connection type based on GridDim
Expand All @@ -49,13 +73,16 @@ struct DefaultConnect {
};

template <typename Cell>
struct DefaultConnect<Cell, 2> : public Connect2D<Cell> {
explicit DefaultConnect(bool commonCorner) : Connect2D<Cell>(commonCorner) {}
DefaultConnect() : DefaultConnect(true) {}
struct DefaultConnect<Cell, 1> : public Connect1D<Cell> {
virtual ~DefaultConnect() = default;
};

template <typename Cell>
struct DefaultConnect<Cell, 1> : public Connect1D<Cell> {};
struct DefaultConnect<Cell, 2> : public Connect2D<Cell> {
explicit DefaultConnect(bool commonCorner) : Connect2D<Cell>(commonCorner) {}
DefaultConnect() = default;
virtual ~DefaultConnect() = default;
};

/// @brief labelClusters
///
Expand All @@ -70,6 +97,8 @@ struct DefaultConnect<Cell, 1> : public Connect1D<Cell> {};
template <typename CellCollection, std::size_t GridDim = 2,
typename Connect =
DefaultConnect<typename CellCollection::value_type, GridDim>>
requires(
Acts::Ccl::HasRetrievableLabelInfo<typename CellCollection::value_type>)
void labelClusters(CellCollection& cells, Connect connect = Connect());

/// @brief mergeClusters
Expand All @@ -82,6 +111,9 @@ void labelClusters(CellCollection& cells, Connect connect = Connect());
/// @return nothing
template <typename CellCollection, typename ClusterCollection,
std::size_t GridDim>
requires(GridDim == 1 || GridDim == 2) &&
Acts::Ccl::HasRetrievableLabelInfo<
typename CellCollection::value_type>
ClusterCollection mergeClusters(CellCollection& /*cells*/);

/// @brief createClusters
Expand Down
100 changes: 25 additions & 75 deletions Core/include/Acts/Clusterization/Clusterization.ipp
Original file line number Diff line number Diff line change
Expand Up @@ -14,86 +14,34 @@

namespace Acts::Ccl::internal {

// Machinery for validating generic Cell/Cluster types at compile-time

template <typename, std::size_t, typename T = void>
struct cellTypeHasRequiredFunctions : std::false_type {};

template <typename T>
struct cellTypeHasRequiredFunctions<
T, 2,
std::void_t<decltype(getCellRow(std::declval<T>())),
decltype(getCellColumn(std::declval<T>())),
decltype(getCellLabel(std::declval<T&>()))>> : std::true_type {
};

template <typename T>
struct cellTypeHasRequiredFunctions<
T, 1,
std::void_t<decltype(getCellColumn(std::declval<T>())),
decltype(getCellLabel(std::declval<T&>()))>> : std::true_type {
};

template <typename, typename, typename T = void>
struct clusterTypeHasRequiredFunctions : std::false_type {};

template <typename T, typename U>
struct clusterTypeHasRequiredFunctions<
T, U,
std::void_t<decltype(clusterAddCell(std::declval<T>(), std::declval<U>()))>>
: std::true_type {};

template <std::size_t GridDim>
constexpr void staticCheckGridDim() {
static_assert(
GridDim == 1 || GridDim == 2,
"mergeClusters is only defined for grid dimensions of 1 or 2. ");
}

template <typename T, std::size_t GridDim>
constexpr void staticCheckCellType() {
constexpr bool hasFns = cellTypeHasRequiredFunctions<T, GridDim>();
static_assert(hasFns,
"Cell type should have the following functions: "
"'int getCellRow(const Cell&)', "
"'int getCellColumn(const Cell&)', "
"'Label& getCellLabel(Cell&)'");
}

template <typename T, typename U>
constexpr void staticCheckClusterType() {
constexpr bool hasFns = clusterTypeHasRequiredFunctions<T, U>();
static_assert(hasFns,
"Cluster type should have the following function: "
"'void clusterAddCell(Cluster&, const Cell&)'");
}

template <typename Cell, std::size_t GridDim>
struct Compare {
static_assert(GridDim != 1 && GridDim != 2,
"Only grid dimensions of 1 or 2 are supported");
};

// Comparator function object for cells, column-wise ordering
// Specialization for 2-D grid
template <typename Cell>
struct Compare<Cell, 2> {
// Specialization for 1-D grids
template <Acts::Ccl::HasRetrievableColumnInfo Cell>
struct Compare<Cell, 1> {
bool operator()(const Cell& c0, const Cell& c1) const {
int row0 = getCellRow(c0);
int row1 = getCellRow(c1);
int col0 = getCellColumn(c0);
int col1 = getCellColumn(c1);
return (col0 == col1) ? row0 < row1 : col0 < col1;
return col0 < col1;
}
};

// Specialization for 1-D grids
// Specialization for 2-D grid
template <typename Cell>
struct Compare<Cell, 1> {
requires(Acts::Ccl::HasRetrievableColumnInfo<Cell> &&
Acts::Ccl::HasRetrievableRowInfo<Cell>)
struct Compare<Cell, 2> {
bool operator()(const Cell& c0, const Cell& c1) const {
int row0 = getCellRow(c0);
int row1 = getCellRow(c1);
int col0 = getCellColumn(c0);
int col1 = getCellColumn(c1);
return col0 < col1;
return (col0 == col1) ? row0 < row1 : col0 < col1;
}
};

Expand Down Expand Up @@ -184,6 +132,10 @@ Connections<GridDim> getConnections(typename std::vector<Cell>::iterator it,
}

template <typename CellCollection, typename ClusterCollection>
requires(
Acts::Ccl::HasRetrievableLabelInfo<typename CellCollection::value_type> &&
Acts::Ccl::CanAcceptCell<typename CellCollection::value_type,
typename ClusterCollection::value_type>)
ClusterCollection mergeClustersImpl(CellCollection& cells) {
using Cluster = typename ClusterCollection::value_type;

Expand Down Expand Up @@ -215,6 +167,8 @@ ClusterCollection mergeClustersImpl(CellCollection& cells) {
namespace Acts::Ccl {

template <typename Cell>
requires(Acts::Ccl::HasRetrievableColumnInfo<Cell> &&
Acts::Ccl::HasRetrievableRowInfo<Cell>)
ConnectResult Connect2D<Cell>::operator()(const Cell& ref,
const Cell& iter) const {
int deltaRow = std::abs(getCellRow(ref) - getCellRow(iter));
Expand All @@ -237,7 +191,7 @@ ConnectResult Connect2D<Cell>::operator()(const Cell& ref,
return ConnectResult::eNoConn;
}

template <typename Cell>
template <Acts::Ccl::HasRetrievableColumnInfo Cell>
ConnectResult Connect1D<Cell>::operator()(const Cell& ref,
const Cell& iter) const {
int deltaCol = std::abs(getCellColumn(ref) - getCellColumn(iter));
Expand Down Expand Up @@ -267,17 +221,19 @@ void recordEquivalences(const internal::Connections<GridDim> seen,
}

template <typename CellCollection, std::size_t GridDim, typename Connect>
requires(
Acts::Ccl::HasRetrievableLabelInfo<typename CellCollection::value_type>)
void labelClusters(CellCollection& cells, Connect connect) {
using Cell = typename CellCollection::value_type;
internal::staticCheckCellType<Cell, GridDim>();

internal::DisjointSets ds{};

// Sort cells by position to enable in-order scan
std::ranges::sort(cells, internal::Compare<Cell, GridDim>());

// First pass: Allocate labels and record equivalences
for (auto it = cells.begin(); it != cells.end(); ++it) {
for (auto it = std::ranges::begin(cells); it != std::ranges::end(cells);
++it) {
const internal::Connections<GridDim> seen =
internal::getConnections<Cell, Connect, GridDim>(it, cells, connect);
if (seen.nconn == 0) {
Expand All @@ -299,13 +255,11 @@ void labelClusters(CellCollection& cells, Connect connect) {

template <typename CellCollection, typename ClusterCollection,
std::size_t GridDim = 2>
requires(GridDim == 1 || GridDim == 2) &&
Acts::Ccl::HasRetrievableLabelInfo<
typename CellCollection::value_type>
ClusterCollection mergeClusters(CellCollection& cells) {
using Cell = typename CellCollection::value_type;
using Cluster = typename ClusterCollection::value_type;
internal::staticCheckGridDim<GridDim>();
internal::staticCheckCellType<Cell, GridDim>();
internal::staticCheckClusterType<Cluster&, const Cell&>();

if constexpr (GridDim > 1) {
// Sort the cells by their cluster label, only needed if more than
// one spatial dimension
Expand All @@ -318,10 +272,6 @@ ClusterCollection mergeClusters(CellCollection& cells) {
template <typename CellCollection, typename ClusterCollection,
std::size_t GridDim, typename Connect>
ClusterCollection createClusters(CellCollection& cells, Connect connect) {
using Cell = typename CellCollection::value_type;
using Cluster = typename ClusterCollection::value_type;
internal::staticCheckCellType<Cell, GridDim>();
internal::staticCheckClusterType<Cluster&, const Cell&>();
labelClusters<CellCollection, GridDim, Connect>(cells, connect);
return mergeClusters<CellCollection, ClusterCollection, GridDim>(cells);
}
Expand Down
39 changes: 39 additions & 0 deletions Core/include/Acts/Clusterization/TimedClusterization.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// This file is part of the ACTS project.
//
// Copyright (C) 2016 CERN for the benefit of the ACTS project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.

#pragma once

#include "Acts/Clusterization/Clusterization.hpp"
#include "Acts/Definitions/Algebra.hpp"

#include <limits>

namespace Acts::Ccl {

template <typename Cell>
concept HasRetrievableTimeInfo = requires(Cell cell) {
{ getCellTime(cell) } -> std::same_as<Acts::ActsScalar>;
};

template <Acts::Ccl::HasRetrievableTimeInfo Cell, std::size_t N>
struct TimedConnect : public Acts::Ccl::DefaultConnect<Cell, N> {
Acts::ActsScalar timeTollerance{std::numeric_limits<Acts::ActsScalar>::max()};

TimedConnect() = default;
TimedConnect(Acts::ActsScalar time);
TimedConnect(Acts::ActsScalar time, bool conn)
requires(N == 2);
virtual ~TimedConnect() = default;

virtual ConnectResult operator()(const Cell& ref,
const Cell& iter) const override;
};

} // namespace Acts::Ccl

#include "Acts/Clusterization/TimedClusterization.ipp"
36 changes: 36 additions & 0 deletions Core/include/Acts/Clusterization/TimedClusterization.ipp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// This file is part of the ACTS project.
//
// Copyright (C) 2016 CERN for the benefit of the ACTS project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.

namespace Acts::Ccl {

template <Acts::Ccl::HasRetrievableTimeInfo Cell, std::size_t N>
TimedConnect<Cell, N>::TimedConnect(Acts::ActsScalar time)
: timeTollerance(time) {}

template <Acts::Ccl::HasRetrievableTimeInfo Cell, std::size_t N>
TimedConnect<Cell, N>::TimedConnect(Acts::ActsScalar time, bool conn)
requires(N == 2)
: Acts::Ccl::DefaultConnect<Cell, N>(conn), timeTollerance(time) {}

template <Acts::Ccl::HasRetrievableTimeInfo Cell, std::size_t N>
Acts::Ccl::ConnectResult TimedConnect<Cell, N>::operator()(
const Cell& ref, const Cell& iter) const {
Acts::Ccl::ConnectResult spaceCompatibility =
Acts::Ccl::DefaultConnect<Cell, N>::operator()(ref, iter);
if (spaceCompatibility != Acts::Ccl::ConnectResult::eConn) {
return spaceCompatibility;
}

if (std::abs(getCellTime(ref) - getCellTime(iter)) < timeTollerance) {
return Acts::Ccl::ConnectResult::eConn;
}

return Acts::Ccl::ConnectResult::eNoConn;
}

} // namespace Acts::Ccl
1 change: 1 addition & 0 deletions Tests/UnitTests/Core/Clusterization/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
add_unittest(Clusterization1D ClusterizationTests1D.cpp)
add_unittest(Clusterization2D ClusterizationTests2D.cpp)
add_unittest(TimedClusterization TimedClusterizationTests.cpp)
Loading

0 comments on commit 0b3beeb

Please sign in to comment.