Minor changes to the MPI routines #84

Merged 4 commits on Dec 11, 2024
4 changes: 2 additions & 2 deletions c++/nda/map.hpp
@@ -195,8 +195,8 @@ namespace nda {
* @return Result of the functor applied to the scalar arguments.
*/
template <Scalar T0, Scalar... Ts>
auto operator()(T0 a0, Ts... as) const {
return f(a0, as...);
auto operator()(T0 t0, Ts... ts) const {
return f(t0, ts...);
}
};

12 changes: 1 addition & 11 deletions c++/nda/mpi/broadcast.hpp
@@ -45,17 +45,7 @@ namespace nda {
* - the array/view is not contiguous with positive strides or
* - one of the MPI calls fails.
*
* @code{.cpp}
* // create an array on all processes
* nda::array<int, 2> A(3, 4);
*
* // ...
* // fill array on root process
* // ...
*
* // broadcast the array to all processes
* mpi::broadcast(A);
* @endcode
* See @ref ex6_p1 for an example.
*
* @tparam A nda::basic_array or nda::basic_array_view type.
* @param a Array/view to be broadcasted from/into.
208 changes: 103 additions & 105 deletions c++/nda/mpi/gather.hpp
@@ -52,14 +52,54 @@ namespace nda::detail {
return std::make_pair(dims, gathered_size);
}

// Helper function that (all)gathers arrays/views.
} // namespace nda::detail

namespace nda {

/**
* @addtogroup av_mpi
* @{
*/

/**
* @brief Implementation of an MPI gather for nda::basic_array or nda::basic_array_view types using a C-style API.
*
* @details The function gathers C-ordered input arrays/views from all processes in the given communicator and
* makes the result available on the root process (`all == false`) or on all processes (`all == true`). The
* arrays/views are joined along the first dimension.
*
* It is expected that all input arrays/views have the same shape on all processes except for the first dimension. The
* function throws an exception if
* - the input array/view is not contiguous with positive strides,
* - the output array/view is not contiguous with positive strides on receiving ranks,
* - the output view does not have the correct shape on receiving ranks or
* - any of the MPI calls fails.
*
* The input arrays/views are simply concatenated along their first dimension. The content of the output array/view
* depends on the MPI rank and whether it receives the data or not:
* - On receiving ranks, it contains the gathered data and has a shape that is the same as the shape of the input
* array/view except for the first dimension, which is the sum of the extents of all input arrays/views along the
* first dimension.
* - On non-receiving ranks, the output array/view is ignored and left unchanged.
*
* If `mpi::has_env` is false or if the communicator size is < 2, it simply copies the input array/view to the output
* array/view.
*
* @tparam A1 nda::basic_array or nda::basic_array_view type with C-layout.
* @tparam A2 nda::basic_array or nda::basic_array_view type with C-layout.
* @param a_in Array/view to be gathered.
* @param a_out Array/view to gather into.
* @param comm `mpi::communicator` object.
* @param root Rank of the root process.
* @param all Should all processes receive the result of the gather.
*/
template <typename A1, typename A2>
requires(is_regular_or_view_v<A1> and std::decay_t<A1>::is_stride_order_C()
and is_regular_or_view_v<A2> and std::decay_t<A2>::is_stride_order_C())
void mpi_gather_impl(A1 const &a_in, A2 &&a_out, mpi::communicator comm = {}, int root = 0, bool all = false) { // NOLINT
void mpi_gather_capi(A1 const &a_in, A2 &&a_out, mpi::communicator comm = {}, int root = 0, bool all = false) { // NOLINT
// check the shape of the input arrays/views
EXPECTS_WITH_MESSAGE(detail::have_mpi_equal_shapes(a_in(nda::range(1), nda::ellipsis{}), comm),
"Error in nda::detail::mpi_gather_impl: Shapes of arrays/views must be equal save the first one");
"Error in nda::mpi_gather_capi: Shapes of arrays/views must be equal save the first one");

// simply copy if there is no active MPI environment or if the communicator size is < 2
if (not mpi::has_env || comm.size() < 2) {
@@ -68,17 +108,17 @@
}

// check if the input arrays/views can be used in the MPI call
detail::check_layout_mpi_compatible(a_in, "detail::mpi_gather_impl");
detail::check_layout_mpi_compatible(a_in, "mpi_gather_capi");

// get output shape, resize or check the output array/view and prepare output span
auto [dims, gathered_size] = mpi_gather_shape_impl(a_in, comm, root, all);
auto [dims, gathered_size] = detail::mpi_gather_shape_impl(a_in, comm, root, all);
auto a_out_span = std::span{a_out.data(), 0};
if (all || (comm.rank() == root)) {
// check if the output array/view can be used in the MPI call
check_layout_mpi_compatible(a_out, "detail::mpi_gather_impl");
detail::check_layout_mpi_compatible(a_out, "mpi_gather_capi");

// resize/check the size of the output array/view
nda::resize_or_check_if_view(a_out, dims);
resize_or_check_if_view(a_out, dims);

// prepare the output span
a_out_span = std::span{a_out.data(), static_cast<std::size_t>(a_out.size())};
@@ -89,7 +129,61 @@
mpi::gather_range(a_in_span, a_out_span, gathered_size, comm, root, all);
}

} // namespace nda::detail
/**
* @brief Implementation of a lazy MPI gather for nda::basic_array or nda::basic_array_view types.
*
* @details This function is lazy, i.e. it returns an mpi::lazy<mpi::tag::gather, A> object without performing the
* actual MPI operation. Since the returned object models an nda::ArrayInitializer, it can be used to
* initialize/assign to nda::basic_array and nda::basic_array_view objects.
*
* The behavior is otherwise similar to nda::mpi_gather.
*
* @warning MPI calls are done in the `invoke` and `shape` methods of the `mpi::lazy` object. If one rank calls one of
* these methods, all ranks in the communicator need to call the same method. Otherwise, the program will deadlock.
*
* @tparam A nda::basic_array or nda::basic_array_view type with C-layout.
* @param a Array/view to be gathered.
* @param comm `mpi::communicator` object.
* @param root Rank of the root process.
* @param all Should all processes receive the result of the gather.
* @return An mpi::lazy<mpi::tag::gather, A> object modelling an nda::ArrayInitializer.
*/
template <typename A>
requires(is_regular_or_view_v<A> and std::decay_t<A>::is_stride_order_C())
auto lazy_mpi_gather(A &&a, mpi::communicator comm = {}, int root = 0, bool all = false) {
return mpi::lazy<mpi::tag::gather, A>{std::forward<A>(a), comm, root, all};
}

/**
* @brief Implementation of an MPI gather for nda::basic_array or nda::basic_array_view types.
*
* @details The function gathers C-ordered input arrays/views from all processes in the given communicator and
* makes the result available on the root process (`all == false`) or on all processes (`all == true`). The
* arrays/views are joined along the first dimension.
*
* It simply constructs an empty array and then calls nda::mpi_gather_capi.
*
* See @ref ex6_p2 for examples.
*
* @tparam A nda::basic_array or nda::basic_array_view type with C-layout.
* @param a Array/view to be gathered.
* @param comm `mpi::communicator` object.
* @param root Rank of the root process.
* @param all Should all processes receive the result of the gather.
* @return An nda::basic_array object with the result of the gathering.
*/
template <typename A>
requires(is_regular_or_view_v<A> and std::decay_t<A>::is_stride_order_C())
auto mpi_gather(A const &a, mpi::communicator comm = {}, int root = 0, bool all = false) {
using return_t = get_regular_t<A>;
return_t a_out;
mpi_gather_capi(a, a_out, comm, root, all);
return a_out;
}

/** @} */

} // namespace nda

/**
* @ingroup av_mpi
@@ -158,102 +252,6 @@ struct mpi::lazy<mpi::tag::gather, A> {
template <nda::Array T>
requires(std::decay_t<T>::is_stride_order_C())
void invoke(T &&target) const { // NOLINT (temporary views are allowed here)
nda::detail::mpi_gather_impl(rhs, target, comm, root, all);
nda::mpi_gather_capi(rhs, target, comm, root, all);
}
};

namespace nda {

/**
* @addtogroup av_mpi
* @{
*/

/**
* @brief Implementation of a lazy MPI gather for nda::basic_array or nda::basic_array_view types.
*
* @details This function is lazy, i.e. it returns an mpi::lazy<mpi::tag::gather, A> object without performing the
* actual MPI operation. Since the returned object models an nda::ArrayInitializer, it can be used to
* initialize/assign to nda::basic_array and nda::basic_array_view objects:
*
* @code{.cpp}
* // create an array on all processes
* nda::array<int, 2> A(3, 4);
*
* // ...
* // fill array on each process
* // ...
*
* // gather the arrays on the root process
* nda::array<int, 2> B = nda::lazy_mpi_gather(A);
* @endcode
*
* The behavior is otherwise identical to nda::mpi_gather.
*
* @warning MPI calls are done in the `invoke` and `shape` methods of the `mpi::lazy` object. If one rank calls one of
* these methods, all ranks in the communicator need to call the same method. Otherwise, the program will deadlock.
*
* @tparam A nda::basic_array or nda::basic_array_view type with C-layout.
* @param a Array/view to be gathered.
* @param comm `mpi::communicator` object.
* @param root Rank of the root process.
* @param all Should all processes receive the result of the gather.
* @return An mpi::lazy<mpi::tag::gather, A> object modelling an nda::ArrayInitializer.
*/
template <typename A>
requires(is_regular_or_view_v<A> and std::decay_t<A>::is_stride_order_C())
auto lazy_mpi_gather(A &&a, mpi::communicator comm = {}, int root = 0, bool all = false) {
return mpi::lazy<mpi::tag::gather, A>{std::forward<A>(a), comm, root, all};
}

/**
* @brief Implementation of an MPI gather for nda::basic_array or nda::basic_array_view types.
*
* @details The function gathers C-ordered input arrays/views from all processes in the given communicator and
* makes the result available on the root process (`all == false`) or on all processes (`all == true`). The
* arrays/views are joined along the first dimension.
*
* It is expected that all input arrays/views have the same shape on all processes except for the first dimension. The
* function throws an exception, if
* - the input array/view is not contiguous with positive strides or
* - any of the MPI calls fails.
*
* The input arrays/views are simply concatenated along their first dimension. The shape of the resulting array
* depends on the MPI rank and whether it receives the data or not:
* - On receiving ranks, the shape is the same as the shape of the input array/view except for the first dimension,
* which is the sum of the extents of all input arrays/views along the first dimension.
* - On non-receiving ranks, the shape is empty, i.e. `(0,0,...,0)`.
*
* @code{.cpp}
* // create an array on all processes
* nda::array<int, 2> A(3, 4);
*
* // ...
* // fill array on each process
* // ...
*
* // gather the arrays on the root process
* auto B = mpi::gather(A);
* @endcode
*
* Here, the array `B` has the shape `(3 * comm.size(), 4)` on the root process and `(0, 0)` on all other processes.
*
* @tparam A nda::basic_array or nda::basic_array_view type with C-layout.
* @param a Array/view to be gathered.
* @param comm `mpi::communicator` object.
* @param root Rank of the root process.
* @param all Should all processes receive the result of the gather.
* @return An nda::basic_array object with the result of the gathering.
*/
template <typename A>
requires(is_regular_or_view_v<A> and std::decay_t<A>::is_stride_order_C())
auto mpi_gather(A const &a, mpi::communicator comm = {}, int root = 0, bool all = false) {
using return_t = get_regular_t<A>;
return_t a_out;
detail::mpi_gather_impl(a, a_out, comm, root, all);
return a_out;
}

/** @} */

} // namespace nda