Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TypeMultiplier, MakeZeroTuple and IdentityTuple #3718

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions Src/Base/AMReX_Reduce.H
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,14 @@ namespace Reduce::detail {

template <std::size_t I, typename T, typename P>
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
void for_each_init (T& t)
constexpr void for_each_init (T& t)
{
P().init(amrex::get<I>(t));
}

template <std::size_t I, typename T, typename P, typename P1, typename... Ps>
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
void for_each_init (T& t)
constexpr void for_each_init (T& t)
{
P().init(amrex::get<I>(t));
for_each_init<I+1,T,P1,Ps...>(t);
Expand Down Expand Up @@ -1275,6 +1275,34 @@ bool AnyOf (Box const& box, P&&pred)

#endif

/**
* \brief Return a GpuTuple containing the identity element for each operation in ReduceOps.
* For example 0, +inf and -inf for ReduceOpSum, ReduceOpMin and ReduceOpMax respectively.
*/
template <typename... Ts, typename... Ps>
AMREX_GPU_HOST_DEVICE
constexpr GpuTuple<Ts...>
IdentityTuple (GpuTuple<Ts...>, ReduceOps<Ps...>) noexcept
{
GpuTuple<Ts...> r{};
Reduce::detail::for_each_init<0, decltype(r), Ps...>(r);
return r;
}

/**
* \brief Return a GpuTuple containing the identity element for each ReduceOp in TypeList.
* For example 0, +inf and -inf for ReduceOpSum, ReduceOpMin and ReduceOpMax respectively.
*/
template <typename... Ts, typename... Ps>
AMREX_GPU_HOST_DEVICE
constexpr GpuTuple<Ts...>
IdentityTuple (GpuTuple<Ts...>, TypeList<Ps...>) noexcept
{
GpuTuple<Ts...> r{};
Reduce::detail::for_each_init<0, decltype(r), Ps...>(r);
return r;
}

}

#endif
14 changes: 14 additions & 0 deletions Src/Base/AMReX_Tuple.H
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,20 @@ ForwardAsTuple (Ts&&... args) noexcept
return GpuTuple<Ts&&...>(std::forward<Ts>(args)...);
}

// MakeZeroTuple

/**
* \brief Return a GpuTuple containing all zeros.
* Note that a default-constructed GpuTuple can have uninitialized values.
*/
template <typename... Ts>
AMREX_GPU_HOST_DEVICE
constexpr GpuTuple<Ts...>
MakeZeroTuple (GpuTuple<Ts...>) noexcept
{
return GpuTuple<Ts...>(static_cast<Ts>(0)...);
}

}

#endif /*AMREX_TUPLE_H_*/
49 changes: 48 additions & 1 deletion Src/Base/AMReX_TypeList.H
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ ForEach (TypeList<Ts...>, F&& f)
// dst and src are either MultiFab or fMultiFab
auto tt = CartesianProduct(TypeList<MultiFab,fMultiFab>{},
TypeList<MultiFab,fMultiFab>{});
bool r = ForEachUtil(tt, [&] (auto t) -> bool
bool r = ForEachUntil(tt, [&] (auto t) -> bool
{
using MF0 = TypeAt<0,decltype(t)>;
using MF1 = TypeAt<1,decltype(t)>;
Expand Down Expand Up @@ -151,6 +151,53 @@ constexpr auto CartesianProduct (Ls...) {
return (TypeList<TypeList<>>{} * ... * Ls{});
}

namespace detail {
// return TypeList<T, T, T, T, ... (N times)> by using the fast power algorithm
template <class T, std::size_t N>
constexpr auto SingleTypeMultiplier_impl () {
if constexpr (N == 0) {
return TypeList<>{};
} else if constexpr (N == 1) {
return TypeList<T>{};
} else if constexpr (N % 2 == 0) {
return SingleTypeMultiplier_impl<T, N / 2>() + SingleTypeMultiplier_impl<T, N / 2>();
} else {
return SingleTypeMultiplier_impl<T, N - 1>() + TypeList<T>{};
}
}

// overload of SingleTypeMultiplier for multiple types:
// convert T[N] to T, T, T, T, ... (N times with N >= 1)
template <class T, std::size_t N>
constexpr auto SingleTypeMultiplier (const T (&)[N]) {
return SingleTypeMultiplier_impl<T, N>();
}

// overload of SingleTypeMultiplier for one regular type
template <class T>
constexpr auto SingleTypeMultiplier (T) {
return TypeList<T>{};
}

// apply the types of the input TypeList as template arguments to TParam
template <template <class...> class TParam, class... Args>
constexpr auto TApply (TypeList<Args...>) {
return TypeList<TParam<Args...>>{};
}
}

/**
* \brief Return the first template argument with the later arguments applied to it.
* Types of the form T[N] are expanded to T, T, T, T, ... (N times with N >= 1).
*
* For example, TypeMultiplier<ReduceData, Real[4], int[2], Long>
* is an alias to the type ReduceData<Real, Real, Real, Real, int, int, Long>.
*/
template <template <class...> class TParam, class... Types>
using TypeMultiplier = TypeAt<0, decltype(detail::TApply<TParam>(
(TypeList<>{} + ... + detail::SingleTypeMultiplier(Types{}))
))>;

}

#endif
Loading