Skip to content

Commit

Permalink
Resolve "Sync DDC"
Browse files Browse the repository at this point in the history
  • Loading branch information
tpadioleau committed Sep 26, 2023
1 parent 27b6bb3 commit 39e0bc8
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 29 deletions.
16 changes: 16 additions & 0 deletions vendor/ddc/include/ddc/for_each.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,22 @@ inline constexpr serial_host_policy serial_host;
inline constexpr parallel_host_policy parallel_host;
inline constexpr parallel_device_policy parallel_device;

template <typename ExecSpace>
constexpr auto policy([[maybe_unused]] ExecSpace exec_space)
{
if constexpr (std::is_same_v<ExecSpace, Kokkos::Serial>) {
return ddc::policies::serial_host;
#ifdef KOKKOS_ENABLE_OPENMP
} else if constexpr (std::is_same_v<ExecSpace, Kokkos::OpenMP>) {
return ddc::policies::parallel_host;
#endif
#if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP)
} else {
return ddc::policies::parallel_device;
#endif
}
}

} // namespace policies

/** iterates over a nD domain using the default execution policy
Expand Down
18 changes: 10 additions & 8 deletions vendor/ddc/include/ddc/pdi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,13 @@ class PdiEvent
template <
PDI_inout_t access,
class Arithmetic,
std::enable_if_t<std::is_arithmetic_v<Arithmetic>, int> = 0>
PdiEvent& with(std::string const& name, Arithmetic& data)
std::enable_if_t<std::is_arithmetic_v<std::remove_reference_t<Arithmetic>>, int> = 0>
PdiEvent& with(std::string const& name, Arithmetic&& data)
{
static_assert(
!(access & PDI_IN) || (default_access_v<Arithmetic> & PDI_IN),
"Invalid access for constant data");
using value_type = std::remove_cv_t<Arithmetic>;
using value_type = std::remove_cv_t<std::remove_reference_t<Arithmetic>>;
PDI_share(name.c_str(), const_cast<value_type*>(&data), access);
m_names.push_back(name);
return *this;
Expand All @@ -84,14 +84,16 @@ class PdiEvent
template <class BorrowedChunk, std::enable_if_t<is_borrowed_chunk_v<BorrowedChunk>, int> = 0>
PdiEvent& with(std::string const& name, BorrowedChunk&& data)
{
return with<chunk_default_access_v<BorrowedChunk>>(name, data);
return with<chunk_default_access_v<BorrowedChunk>>(name, std::forward<BorrowedChunk>(data));
}

/// Arithmetic overload (only lvalue-ref)
template <class Arithmetic, std::enable_if_t<std::is_arithmetic_v<Arithmetic>, int> = 0>
PdiEvent& with(std::string const& name, Arithmetic& data)
/// Arithmetic overload
template <
class Arithmetic,
std::enable_if_t<std::is_arithmetic_v<std::remove_reference_t<Arithmetic>>, int> = 0>
PdiEvent& with(std::string const& name, Arithmetic&& data)
{
return with<default_access_v<Arithmetic>>(name, data);
return with<default_access_v<Arithmetic>>(name, std::forward<Arithmetic>(data));
}

/// With synonym
Expand Down
29 changes: 8 additions & 21 deletions vendor/ddc/tests/fft.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,26 +20,12 @@ using DDom = ddc::DiscreteDomain<DDim...>;
template <typename Kx>
using DFDim = ddc::PeriodicSampling<Kx>;

template <typename ExecSpace>
constexpr auto policy = [] {
if constexpr (std::is_same_v<ExecSpace, Kokkos::Serial>) {
return ddc::policies::serial_host;
}
#if fftw_omp_AVAIL
else if constexpr (std::is_same_v<ExecSpace, Kokkos::OpenMP>) {
return ddc::policies::parallel_host;
}
#endif
else {
return ddc::policies::parallel_device;
}
};

// TODO:
// - FFT multidim but according to a subset of dimensions
template <typename ExecSpace, typename MemorySpace, typename Tin, typename Tout, typename... X>
static void test_fft()
{
const ExecSpace exec_space = ExecSpace();
constexpr bool full_fft
= ddc::detail::fft::is_complex_v<Tin> && ddc::detail::fft::is_complex_v<Tout>;
const double a = -10;
Expand All @@ -57,7 +43,7 @@ static void test_fft()
ddc::Chunk _f(x_mesh, ddc::KokkosAllocator<Tin, MemorySpace>());
ddc::ChunkSpan f = _f.span_view();
ddc::for_each(
policy<ExecSpace>(),
ddc::policies::policy(exec_space),
f.domain(),
DDC_LAMBDA(DElem<DDim<X>...> const e) {
ddc::Real const xn2
Expand All @@ -70,7 +56,7 @@ static void test_fft()

ddc::Chunk Ff_alloc(k_mesh, ddc::KokkosAllocator<Tout, MemorySpace>());
ddc::ChunkSpan Ff = Ff_alloc.span_view();
ddc::fft(ExecSpace(), Ff, f_bis, {ddc::FFT_Normalization::FULL});
ddc::fft(exec_space, Ff, f_bis, {ddc::FFT_Normalization::FULL});
Kokkos::fence();

// deepcopy of Ff because FFT C2R overwrites the input
Expand All @@ -80,7 +66,7 @@ static void test_fft()

ddc::Chunk FFf_alloc(f.domain(), ddc::KokkosAllocator<Tin, MemorySpace>());
ddc::ChunkSpan FFf = FFf_alloc.span_view();
ddc::ifft(ExecSpace(), FFf, Ff_bis, {ddc::FFT_Normalization::FULL});
ddc::ifft(exec_space, FFf, Ff_bis, {ddc::FFT_Normalization::FULL});

ddc::Chunk f_host_alloc(f.domain(), ddc::HostAllocator<Tin>());
ddc::ChunkSpan f_host = f_host_alloc.span_view();
Expand Down Expand Up @@ -130,6 +116,7 @@ static void test_fft()
template <typename ExecSpace, typename MemorySpace, typename Tin, typename Tout, typename X>
static void test_fft_norm(ddc::FFT_Normalization const norm)
{
const ExecSpace exec_space = ExecSpace();
constexpr bool full_fft
= ddc::detail::fft::is_complex_v<Tin> && ddc::detail::fft::is_complex_v<Tout>;

Expand All @@ -143,7 +130,7 @@ static void test_fft_norm(ddc::FFT_Normalization const norm)
ddc::Chunk f_alloc = ddc::Chunk(x_mesh, ddc::KokkosAllocator<Tin, MemorySpace>());
ddc::ChunkSpan f = f_alloc.span_view();
ddc::for_each(
policy<ExecSpace>(),
ddc::policies::policy(exec_space),
f.domain(),
DDC_LAMBDA(DElem<DDim<X>> const e) { f(e) = static_cast<Tin>(1); });

Expand All @@ -154,7 +141,7 @@ static void test_fft_norm(ddc::FFT_Normalization const norm)

ddc::Chunk Ff_alloc = ddc::Chunk(k_mesh, ddc::KokkosAllocator<Tout, MemorySpace>());
ddc::ChunkSpan Ff = Ff_alloc.span_view();
ddc::fft(ExecSpace(), Ff, f_bis, {norm});
ddc::fft(exec_space, Ff, f_bis, {norm});
Kokkos::fence();

// deepcopy of Ff because FFT C2R overwrites the input
Expand All @@ -164,7 +151,7 @@ static void test_fft_norm(ddc::FFT_Normalization const norm)

ddc::Chunk FFf_alloc = ddc::Chunk(x_mesh, ddc::KokkosAllocator<Tin, MemorySpace>());
ddc::ChunkSpan FFf = FFf_alloc.span_view();
ddc::ifft(ExecSpace(), FFf, Ff_bis, {norm});
ddc::ifft(exec_space, FFf, Ff_bis, {norm});

double const f_sum = ddc::transform_reduce(f.domain(), 0., ddc::reducer::sum<double>(), f);

Expand Down

0 comments on commit 39e0bc8

Please sign in to comment.