Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(util): Add builder class for "chain" delegates #3696

Merged
merged 12 commits into from
Oct 9, 2024
47 changes: 31 additions & 16 deletions Core/include/Acts/Utilities/Delegate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,20 @@ class Delegate;
///
template <typename R, typename H, DelegateType O, typename... Args>
class Delegate<R(Args...), H, O> {
public:
static constexpr DelegateType kOwnership = O;

/// Alias of the return type
using return_type = R;
using holder_type = H;
/// Alias to the function pointer type this class will store
using function_type = return_type (*)(const holder_type *, Args...);

using function_ptr_type = return_type (*)(Args...);
using signature_type = R(Args...);

using deleter_type = void (*)(const holder_type *);

private:
template <typename T, typename C>
using isSignatureCompatible =
decltype(std::declval<T &>() = std::declval<C>());
Expand Down Expand Up @@ -157,7 +159,10 @@ class Delegate<R(Args...), H, O> {
/// @note The function pointer must be ``constexpr`` for @c Delegate to accept it
/// @tparam Callable The compile-time free function pointer
template <auto Callable>
void connect() {
void connect()
requires(
Concepts::invocable_and_returns<Callable, return_type, Args && ...>)
{
m_payload.payload = nullptr;

static_assert(
Expand Down Expand Up @@ -202,6 +207,14 @@ class Delegate<R(Args...), H, O> {
m_function = callable;
}

template <typename Type>
void connect(function_type callable, const Type *instance)
requires(kOwnership == DelegateType::NonOwning)
{
m_payload.payload = instance;
m_function = callable;
}

/// Connect a member function to be called on an instance
/// @tparam Callable The compile-time member function pointer
/// @tparam Type The type of the instance the member function should be called on
Expand All @@ -210,13 +223,11 @@ class Delegate<R(Args...), H, O> {
/// it's lifetime is longer than that of @c Delegate.
template <auto Callable, typename Type>
void connect(const Type *instance)
requires(kOwnership == DelegateType::NonOwning)
{
static_assert(Concepts::invocable_and_returns<Callable, return_type, Type,
Args &&...>,
"Callable given does not correspond exactly to required call "
"signature");
requires(kOwnership == DelegateType::NonOwning &&
Concepts::invocable_and_returns<Callable, return_type, Type,
Args && ...>)

{
m_payload.payload = instance;

m_function = [](const holder_type *payload, Args... args) -> return_type {
Expand All @@ -234,13 +245,10 @@ class Delegate<R(Args...), H, O> {
/// @note @c Delegate assumes owner ship over @p instance.
template <auto Callable, typename Type>
void connect(std::unique_ptr<const Type> instance)
requires(kOwnership == DelegateType::Owning)
requires(kOwnership == DelegateType::Owning &&
Concepts::invocable_and_returns<Callable, return_type, Type,
Args && ...>)
{
static_assert(Concepts::invocable_and_returns<Callable, return_type, Type,
Args &&...>,
"Callable given does not correspond exactly to required call "
"signature");

m_payload.payload = std::unique_ptr<const holder_type, deleter_type>(
instance.release(), [](const holder_type *payload) {
const auto *concretePayload = static_cast<const Type *>(payload);
Expand All @@ -259,7 +267,9 @@ class Delegate<R(Args...), H, O> {
/// @param args The arguments to call the contained function with
/// @return Return value of the contained function
template <typename... Ts>
return_type operator()(Ts &&...args) const {
return_type operator()(Ts &&...args) const
requires(std::is_invocable_v<function_type, const holder_type *, Ts...>)
{
assert(connected() && "Delegate is not connected");
return std::invoke(m_function, m_payload.ptr(), std::forward<Ts>(args)...);
}
Expand Down Expand Up @@ -328,6 +338,11 @@ class OwningDelegate;
/// Alias for an owning delegate
template <typename R, typename H, typename... Args>
class OwningDelegate<R(Args...), H>
: public Delegate<R(Args...), H, DelegateType::Owning> {};
: public Delegate<R(Args...), H, DelegateType::Owning> {
public:
OwningDelegate() = default;
OwningDelegate(Delegate<R(Args...), H, DelegateType::Owning> &&delegate)
: Delegate<R(Args...), H, DelegateType::Owning>(std::move(delegate)) {}
};

} // namespace Acts
161 changes: 161 additions & 0 deletions Core/include/Acts/Utilities/DelegateChainBuilder.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
// This file is part of the ACTS project.
//
// Copyright (C) 2016 CERN for the benefit of the ACTS project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.

#pragma once

#include "Acts/Utilities/Delegate.hpp"
#include "Acts/Utilities/TypeList.hpp"
#include "Acts/Utilities/TypeTag.hpp"

#include <tuple>
#include <type_traits>
#include <utility>

namespace Acts {

template <typename, typename payload_types = TypeList<>, auto... Callables>
class DelegateChainBuilder;

// @TODO: Maybe add concept requirement for default initialization of R
andiwand marked this conversation as resolved.
Show resolved Hide resolved
template <typename R, typename... payload_types, auto... callables,
typename... callable_args>
class DelegateChainBuilder<R(callable_args...), TypeList<payload_types...>,
callables...> {
using return_type =
std::conditional_t<std::is_same_v<R, void>, void,
std::array<R, sizeof...(payload_types)>>;
using delegate_type =
Delegate<return_type(callable_args...), void, DelegateType::Owning>;
using tuple_type = std::tuple<payload_types...>;

public:
template <typename, typename Ps, auto... Cs>
friend class DelegateChainBuilder;

DelegateChainBuilder() = default;

template <typename D>
DelegateChainBuilder(const D& /*unused*/) {}

template <auto Callable, typename payload_type>
constexpr auto add(payload_type payload)
requires(std::is_pointer_v<payload_type>)
{
std::tuple<payload_types..., payload_type> payloads =
std::tuple_cat(m_payloads, std::make_tuple(payload));

return DelegateChainBuilder<R(callable_args...),
TypeList<payload_types..., payload_type>,
callables..., Callable>{payloads};
}

template <auto Callable>
constexpr auto add() {
std::tuple<payload_types..., std::nullptr_t> payloads =
std::tuple_cat(m_payloads, std::make_tuple(std::nullptr_t{}));

return DelegateChainBuilder<R(callable_args...),
TypeList<payload_types..., std::nullptr_t>,
callables..., Callable>{payloads};
}

delegate_type build()
requires(sizeof...(callables) > 0)
{
auto block = std::make_unique<const DispatchBlock>(m_payloads);
delegate_type delegate;
delegate.template connect<&DispatchBlock::dispatch>(std::move(block));
return delegate;
}

void store(delegate_type& delegate)
requires(sizeof...(callables) > 0)
{
auto block = std::make_unique<const DispatchBlock>(m_payloads);
delegate.template connect<&DispatchBlock::dispatch>(std::move(block));
}

void store(Delegate<R(callable_args...)>& delegate)
requires(sizeof...(callables) == 1)
{
constexpr auto callable =
DispatchBlock::template findCallable<0, 0, callables...>();
delegate.template connect<callable>(std::get<0>(m_payloads));

// auto block = std::make_unique<const DispatchBlock>(m_payloads);
// delegate.template connect<&DispatchBlock::dispatch>(std::move(block));
}

private:
DelegateChainBuilder(std::tuple<payload_types...> payloads)
: m_payloads(payloads) {}

struct DispatchBlock {
template <std::size_t I, std::size_t J, auto head, auto... tail>
static constexpr auto findCallable() {
if constexpr (I == J) {
return head;
} else {
return findCallable<I, J + 1, tail...>();
}
}

template <std::size_t I = 0, typename result_ptr>
static constexpr auto invoke(result_ptr result, const tuple_type* payloads,
callable_args... args) {
const auto& callable = findCallable<I, 0, callables...>();

if constexpr (!std::is_same_v<std::tuple_element_t<I, tuple_type>,
std::nullptr_t>) {
auto payload = std::get<I>(*payloads);

if constexpr (!std::is_same_v<result_ptr, std::nullptr_t>) {
std::get<I>(*result) = std::invoke(
callable, payload, std::forward<callable_args>(args)...);
} else {
std::invoke(callable, payload, std::forward<callable_args>(args)...);
}

} else {
if constexpr (!std::is_same_v<result_ptr, std::nullptr_t>) {
std::get<I>(*result) =
std::invoke(callable, std::forward<callable_args>(args)...);
} else {
std::invoke(callable, std::forward<callable_args>(args)...);
}
}

if constexpr (I < sizeof...(payload_types) - 1) {
invoke<I + 1>(result, payloads, std::forward<callable_args>(args)...);
}
}

DispatchBlock(tuple_type payloads) : m_payloads(std::move(payloads)) {}

tuple_type m_payloads{};

auto dispatch(callable_args... args) const {
if constexpr (std::is_same_v<R, void>) {
invoke(nullptr, &m_payloads, std::forward<callable_args>(args)...);
} else {
return_type result{};
invoke(&result, &m_payloads, std::forward<callable_args>(args)...);
return result;
}
}
};

private:
tuple_type m_payloads{};
};

template <typename D>
DelegateChainBuilder(const D& /*unused*/)
-> DelegateChainBuilder<typename D::signature_type>;

} // namespace Acts
1 change: 1 addition & 0 deletions Tests/UnitTests/Core/Utilities/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ add_unittest(Result ResultTests.cpp)
add_unittest(TypeList TypeListTests.cpp)
add_unittest(UnitVectors UnitVectorsTests.cpp)
add_unittest(Delegate DelegateTests.cpp)
add_unittest(DelegateChainBuilder DelegateChainBuilderTests.cpp)
add_unittest(HashedString HashedStringTests.cpp)
if(ACTS_BUILD_CUDA_FEATURES)
add_unittest(Cuda CudaTests.cu)
Expand Down
118 changes: 118 additions & 0 deletions Tests/UnitTests/Core/Utilities/DelegateChainBuilderTests.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
// This file is part of the ACTS project.
//
// Copyright (C) 2016 CERN for the benefit of the ACTS project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.

#include <boost/test/unit_test.hpp>
#include <boost/test/unit_test_suite.hpp>

#include "Acts/Utilities/DelegateChainBuilder.hpp"

using namespace Acts;

struct AddTo {
int value = 0;

void add(int &x) const { x += value; }
};

void addFive(int &x) {
x += 5;
}

BOOST_AUTO_TEST_SUITE(DelegateChainBuilderTests)

BOOST_AUTO_TEST_CASE(DelegateChainBuilderAdd) {
AddTo a1{1}, a2{2}, a3{3};
int x = 0;

// Basic building
OwningDelegate<void(int &)> chain = DelegateChainBuilder<void(int &)>{}
.add<&AddTo::add>(&a1)
.add<&addFive>()
.add<&AddTo::add>(&a2)
.add<&AddTo::add>(&a3)
.build();
paulgessinger marked this conversation as resolved.
Show resolved Hide resolved
chain(x);
BOOST_CHECK_EQUAL(x, 11);

chain.disconnect();

// In case of no return types, we can rebind the owning delegate with a chain
// of different size
chain = DelegateChainBuilder<void(int &)>{}
.add<&AddTo::add>(&a1)
.add<&addFive>()
.add<&AddTo::add>(&a3)
.build();

x = 0;

chain(x);
BOOST_CHECK_EQUAL(x, 9);

// CTAD helper from delegate type
chain = DelegateChainBuilder{chain}
.add<&AddTo::add>(&a1)
.add<&addFive>()
.add<&AddTo::add>(&a3)
.build();

x = 0;

chain(x);
BOOST_CHECK_EQUAL(x, 9);

Delegate<void(int &)> nonOwning;

// In case of a single callable, we can store it in a non-owning delegate
DelegateChainBuilder<void(int &)>{}.add<&AddTo::add>(&a1).store(nonOwning);

x = 0;
nonOwning(x);
BOOST_CHECK_EQUAL(x, 1);
}

struct GetInt {
int value;

int get() const { return value; }
};

int getSix() {
return 6;
}

BOOST_AUTO_TEST_CASE(DelegateChainBuilderReturn) {
GetInt g1{1}, g2{2}, g3{3};

Delegate<std::array<int, 4>(), void, DelegateType::Owning> chain =
andiwand marked this conversation as resolved.
Show resolved Hide resolved
DelegateChainBuilder<int()>{}
.add<&GetInt::get>(&g1)
.add<&getSix>()
.add<&GetInt::get>(&g2)
.add<&GetInt::get>(&g3)
.build();

auto results = chain();
std::vector<int> expected = {1, 6, 2, 3};
BOOST_CHECK_EQUAL_COLLECTIONS(results.begin(), results.end(),
expected.begin(), expected.end());

Delegate<std::array<int, 3>(), void, DelegateType::Owning> delegate;
DelegateChainBuilder<int()>{}
.add<&GetInt::get>(&g1)
.add<&getSix>()
.add<&GetInt::get>(&g3)
.store(delegate);

auto results2 = delegate();
expected = {1, 6, 3};
BOOST_CHECK_EQUAL_COLLECTIONS(results2.begin(), results2.end(),
expected.begin(), expected.end());
}

BOOST_AUTO_TEST_SUITE_END()
Loading