Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add support for Timed Clusterization #3654

Merged
merged 7 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 41 additions & 9 deletions Core/include/Acts/Clusterization/Clusterization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,26 @@

namespace Acts::Ccl {

template <typename Cell>
concept HasRetrievableColumnInfo = requires(Cell cell) {
{ getCellColumn(cell) } -> std::same_as<int>;
};

template <typename Cell>
concept HasRetrievableRowInfo = requires(Cell cell) {
{ getCellRow(cell) } -> std::same_as<int>;
};

template <typename Cell>
concept HasRetrievableLabelInfo = requires(Cell cell) {
{ getCellLabel(cell) } -> std::same_as<int&>;
};

template <typename Cell, typename Cluster>
concept CanAcceptCell = requires(Cell cell, Cluster cluster) {
{ clusterAddCell(cluster, cell) } -> std::same_as<void>;
};

using Label = int;
constexpr Label NO_LABEL = 0;

Expand All @@ -28,17 +48,21 @@ enum class ConnectResult {

// Default connection type for 2-D grids: 4- or 8-cell connectivity
template <typename Cell>
requires(Acts::Ccl::HasRetrievableColumnInfo<Cell> &&
Acts::Ccl::HasRetrievableRowInfo<Cell>)
struct Connect2D {
bool conn8;
Connect2D() : conn8{true} {}
bool conn8{true};
Connect2D() = default;
explicit Connect2D(bool commonCorner) : conn8{commonCorner} {}
ConnectResult operator()(const Cell& ref, const Cell& iter) const;
virtual ConnectResult operator()(const Cell& ref, const Cell& iter) const;
virtual ~Connect2D() = default;
};

// Default connection type for 1-D grids: 2-cell connectivity
template <typename Cell>
template <Acts::Ccl::HasRetrievableColumnInfo Cell>
struct Connect1D {
ConnectResult operator()(const Cell& ref, const Cell& iter) const;
virtual ConnectResult operator()(const Cell& ref, const Cell& iter) const;
virtual ~Connect1D() = default;
};

// Default connection type based on GridDim
Expand All @@ -49,13 +73,16 @@ struct DefaultConnect {
};

template <typename Cell>
struct DefaultConnect<Cell, 2> : public Connect2D<Cell> {
explicit DefaultConnect(bool commonCorner) : Connect2D<Cell>(commonCorner) {}
DefaultConnect() : DefaultConnect(true) {}
struct DefaultConnect<Cell, 1> : public Connect1D<Cell> {
~DefaultConnect() override = default;
};

template <typename Cell>
struct DefaultConnect<Cell, 1> : public Connect1D<Cell> {};
struct DefaultConnect<Cell, 2> : public Connect2D<Cell> {
explicit DefaultConnect(bool commonCorner) : Connect2D<Cell>(commonCorner) {}
DefaultConnect() = default;
~DefaultConnect() override = default;
};

/// @brief labelClusters
///
Expand All @@ -70,6 +97,8 @@ struct DefaultConnect<Cell, 1> : public Connect1D<Cell> {};
template <typename CellCollection, std::size_t GridDim = 2,
typename Connect =
DefaultConnect<typename CellCollection::value_type, GridDim>>
requires(
Acts::Ccl::HasRetrievableLabelInfo<typename CellCollection::value_type>)
void labelClusters(CellCollection& cells, Connect connect = Connect());

/// @brief mergeClusters
Expand All @@ -82,6 +111,9 @@ void labelClusters(CellCollection& cells, Connect connect = Connect());
/// @return nothing
template <typename CellCollection, typename ClusterCollection,
std::size_t GridDim>
requires(GridDim == 1 || GridDim == 2) &&
Acts::Ccl::HasRetrievableLabelInfo<
typename CellCollection::value_type>
ClusterCollection mergeClusters(CellCollection& /*cells*/);

/// @brief createClusters
Expand Down
100 changes: 25 additions & 75 deletions Core/include/Acts/Clusterization/Clusterization.ipp
Original file line number Diff line number Diff line change
Expand Up @@ -14,86 +14,34 @@

namespace Acts::Ccl::internal {

// Machinery for validating generic Cell/Cluster types at compile-time

template <typename, std::size_t, typename T = void>
struct cellTypeHasRequiredFunctions : std::false_type {};

template <typename T>
struct cellTypeHasRequiredFunctions<
T, 2,
std::void_t<decltype(getCellRow(std::declval<T>())),
decltype(getCellColumn(std::declval<T>())),
decltype(getCellLabel(std::declval<T&>()))>> : std::true_type {
};

template <typename T>
struct cellTypeHasRequiredFunctions<
T, 1,
std::void_t<decltype(getCellColumn(std::declval<T>())),
decltype(getCellLabel(std::declval<T&>()))>> : std::true_type {
};

template <typename, typename, typename T = void>
struct clusterTypeHasRequiredFunctions : std::false_type {};

template <typename T, typename U>
struct clusterTypeHasRequiredFunctions<
T, U,
std::void_t<decltype(clusterAddCell(std::declval<T>(), std::declval<U>()))>>
: std::true_type {};

template <std::size_t GridDim>
constexpr void staticCheckGridDim() {
static_assert(
GridDim == 1 || GridDim == 2,
"mergeClusters is only defined for grid dimensions of 1 or 2. ");
}

template <typename T, std::size_t GridDim>
constexpr void staticCheckCellType() {
constexpr bool hasFns = cellTypeHasRequiredFunctions<T, GridDim>();
static_assert(hasFns,
"Cell type should have the following functions: "
"'int getCellRow(const Cell&)', "
"'int getCellColumn(const Cell&)', "
"'Label& getCellLabel(Cell&)'");
}

template <typename T, typename U>
constexpr void staticCheckClusterType() {
constexpr bool hasFns = clusterTypeHasRequiredFunctions<T, U>();
static_assert(hasFns,
"Cluster type should have the following function: "
"'void clusterAddCell(Cluster&, const Cell&)'");
}

template <typename Cell, std::size_t GridDim>
struct Compare {
static_assert(GridDim != 1 && GridDim != 2,
"Only grid dimensions of 1 or 2 are supported");
};

// Comparator function object for cells, column-wise ordering
// Specialization for 2-D grid
template <typename Cell>
struct Compare<Cell, 2> {
// Specialization for 1-D grids
template <Acts::Ccl::HasRetrievableColumnInfo Cell>
struct Compare<Cell, 1> {
bool operator()(const Cell& c0, const Cell& c1) const {
int row0 = getCellRow(c0);
int row1 = getCellRow(c1);
int col0 = getCellColumn(c0);
int col1 = getCellColumn(c1);
return (col0 == col1) ? row0 < row1 : col0 < col1;
return col0 < col1;
}
};

// Specialization for 1-D grids
// Specialization for 2-D grid
template <typename Cell>
struct Compare<Cell, 1> {
requires(Acts::Ccl::HasRetrievableColumnInfo<Cell> &&
Acts::Ccl::HasRetrievableRowInfo<Cell>)
struct Compare<Cell, 2> {
bool operator()(const Cell& c0, const Cell& c1) const {
int row0 = getCellRow(c0);
int row1 = getCellRow(c1);
int col0 = getCellColumn(c0);
int col1 = getCellColumn(c1);
return col0 < col1;
return (col0 == col1) ? row0 < row1 : col0 < col1;
}
};

Expand Down Expand Up @@ -184,6 +132,10 @@ Connections<GridDim> getConnections(typename std::vector<Cell>::iterator it,
}

template <typename CellCollection, typename ClusterCollection>
requires(
Acts::Ccl::HasRetrievableLabelInfo<typename CellCollection::value_type> &&
Acts::Ccl::CanAcceptCell<typename CellCollection::value_type,
typename ClusterCollection::value_type>)
ClusterCollection mergeClustersImpl(CellCollection& cells) {
using Cluster = typename ClusterCollection::value_type;

Expand Down Expand Up @@ -215,6 +167,8 @@ ClusterCollection mergeClustersImpl(CellCollection& cells) {
namespace Acts::Ccl {

template <typename Cell>
requires(Acts::Ccl::HasRetrievableColumnInfo<Cell> &&
Acts::Ccl::HasRetrievableRowInfo<Cell>)
ConnectResult Connect2D<Cell>::operator()(const Cell& ref,
const Cell& iter) const {
int deltaRow = std::abs(getCellRow(ref) - getCellRow(iter));
Expand All @@ -237,7 +191,7 @@ ConnectResult Connect2D<Cell>::operator()(const Cell& ref,
return ConnectResult::eNoConn;
}

template <typename Cell>
template <Acts::Ccl::HasRetrievableColumnInfo Cell>
ConnectResult Connect1D<Cell>::operator()(const Cell& ref,
const Cell& iter) const {
int deltaCol = std::abs(getCellColumn(ref) - getCellColumn(iter));
Expand Down Expand Up @@ -267,17 +221,19 @@ void recordEquivalences(const internal::Connections<GridDim> seen,
}

template <typename CellCollection, std::size_t GridDim, typename Connect>
requires(
Acts::Ccl::HasRetrievableLabelInfo<typename CellCollection::value_type>)
void labelClusters(CellCollection& cells, Connect connect) {
using Cell = typename CellCollection::value_type;
internal::staticCheckCellType<Cell, GridDim>();

internal::DisjointSets ds{};

// Sort cells by position to enable in-order scan
std::ranges::sort(cells, internal::Compare<Cell, GridDim>());

// First pass: Allocate labels and record equivalences
for (auto it = cells.begin(); it != cells.end(); ++it) {
for (auto it = std::ranges::begin(cells); it != std::ranges::end(cells);
++it) {
const internal::Connections<GridDim> seen =
internal::getConnections<Cell, Connect, GridDim>(it, cells, connect);
if (seen.nconn == 0) {
Expand All @@ -299,13 +255,11 @@ void labelClusters(CellCollection& cells, Connect connect) {

template <typename CellCollection, typename ClusterCollection,
std::size_t GridDim = 2>
requires(GridDim == 1 || GridDim == 2) &&
Acts::Ccl::HasRetrievableLabelInfo<
typename CellCollection::value_type>
ClusterCollection mergeClusters(CellCollection& cells) {
using Cell = typename CellCollection::value_type;
using Cluster = typename ClusterCollection::value_type;
internal::staticCheckGridDim<GridDim>();
internal::staticCheckCellType<Cell, GridDim>();
internal::staticCheckClusterType<Cluster&, const Cell&>();

if constexpr (GridDim > 1) {
// Sort the cells by their cluster label, only needed if more than
// one spatial dimension
Expand All @@ -318,10 +272,6 @@ ClusterCollection mergeClusters(CellCollection& cells) {
template <typename CellCollection, typename ClusterCollection,
std::size_t GridDim, typename Connect>
ClusterCollection createClusters(CellCollection& cells, Connect connect) {
using Cell = typename CellCollection::value_type;
using Cluster = typename ClusterCollection::value_type;
internal::staticCheckCellType<Cell, GridDim>();
internal::staticCheckClusterType<Cluster&, const Cell&>();
labelClusters<CellCollection, GridDim, Connect>(cells, connect);
return mergeClusters<CellCollection, ClusterCollection, GridDim>(cells);
}
Expand Down
38 changes: 38 additions & 0 deletions Core/include/Acts/Clusterization/TimedClusterization.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// This file is part of the ACTS project.
//
// Copyright (C) 2016 CERN for the benefit of the ACTS project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.

#pragma once

#include "Acts/Clusterization/Clusterization.hpp"
#include "Acts/Definitions/Algebra.hpp"

#include <limits>

namespace Acts::Ccl {

template <typename Cell>
concept HasRetrievableTimeInfo = requires(Cell cell) {
{ getCellTime(cell) } -> std::same_as<Acts::ActsScalar>;
};

template <Acts::Ccl::HasRetrievableTimeInfo Cell, std::size_t N>
struct TimedConnect : public Acts::Ccl::DefaultConnect<Cell, N> {
Acts::ActsScalar timeTollerance{std::numeric_limits<Acts::ActsScalar>::max()};
CarloVarni marked this conversation as resolved.
Show resolved Hide resolved

TimedConnect() = default;
TimedConnect(Acts::ActsScalar time);
TimedConnect(Acts::ActsScalar time, bool conn)
CarloVarni marked this conversation as resolved.
Show resolved Hide resolved
requires(N == 2);
CarloVarni marked this conversation as resolved.
Show resolved Hide resolved
~TimedConnect() override = default;

ConnectResult operator()(const Cell& ref, const Cell& iter) const override;
};

} // namespace Acts::Ccl

#include "Acts/Clusterization/TimedClusterization.ipp"
36 changes: 36 additions & 0 deletions Core/include/Acts/Clusterization/TimedClusterization.ipp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// This file is part of the ACTS project.
//
// Copyright (C) 2016 CERN for the benefit of the ACTS project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.

namespace Acts::Ccl {

template <Acts::Ccl::HasRetrievableTimeInfo Cell, std::size_t N>
TimedConnect<Cell, N>::TimedConnect(Acts::ActsScalar time)
: timeTollerance(time) {}

template <Acts::Ccl::HasRetrievableTimeInfo Cell, std::size_t N>
TimedConnect<Cell, N>::TimedConnect(Acts::ActsScalar time, bool conn)
CarloVarni marked this conversation as resolved.
Show resolved Hide resolved
requires(N == 2)
: Acts::Ccl::DefaultConnect<Cell, N>(conn), timeTollerance(time) {}

template <Acts::Ccl::HasRetrievableTimeInfo Cell, std::size_t N>
Acts::Ccl::ConnectResult TimedConnect<Cell, N>::operator()(
const Cell& ref, const Cell& iter) const {
Acts::Ccl::ConnectResult spaceCompatibility =
Acts::Ccl::DefaultConnect<Cell, N>::operator()(ref, iter);
if (spaceCompatibility != Acts::Ccl::ConnectResult::eConn) {
return spaceCompatibility;
}

if (std::abs(getCellTime(ref) - getCellTime(iter)) < timeTollerance) {
return Acts::Ccl::ConnectResult::eConn;
}

return Acts::Ccl::ConnectResult::eNoConn;
}

} // namespace Acts::Ccl
1 change: 1 addition & 0 deletions Tests/UnitTests/Core/Clusterization/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
add_unittest(Clusterization1D ClusterizationTests1D.cpp)
add_unittest(Clusterization2D ClusterizationTests2D.cpp)
add_unittest(TimedClusterization TimedClusterizationTests.cpp)
Loading
Loading