From 6fd0e7ff1d829917358c3e974ee8239358903938 Mon Sep 17 00:00:00 2001 From: Jiakun Yan Date: Mon, 18 Nov 2024 16:20:32 -0500 Subject: [PATCH] Fix messages larger than INT_MAX for mpi --- .../include/hpx/mpi_base/mpi_environment.hpp | 5 + libs/core/mpi_base/src/mpi_environment.cpp | 92 +++++++++++++++++++ .../parcelport_mpi/receiver_connection.hpp | 68 ++++---------- .../hpx/parcelport_mpi/sender_connection.hpp | 69 ++++---------- 4 files changed, 130 insertions(+), 104 deletions(-) diff --git a/libs/core/mpi_base/include/hpx/mpi_base/mpi_environment.hpp b/libs/core/mpi_base/include/hpx/mpi_base/mpi_environment.hpp index dde6149131b..384b993b35b 100644 --- a/libs/core/mpi_base/include/hpx/mpi_base/mpi_environment.hpp +++ b/libs/core/mpi_base/include/hpx/mpi_base/mpi_environment.hpp @@ -1,4 +1,5 @@ // Copyright (c) 2013-2015 Thomas Heller +// Copyright (c) 2024 Jiakun Yan // // SPDX-License-Identifier: BSL-1.0 // Distributed under the Boost Software License, Version 1.0. (See accompanying @@ -42,6 +43,10 @@ namespace hpx::util { static std::string get_processor_name(); + static MPI_Datatype type_contiguous(size_t nbytes); + static MPI_Request isend(void* address, size_t size, int rank, int tag); + static MPI_Request irecv(void* address, size_t size, int rank, int tag); + struct HPX_CORE_EXPORT scoped_lock { scoped_lock(); diff --git a/libs/core/mpi_base/src/mpi_environment.cpp b/libs/core/mpi_base/src/mpi_environment.cpp index a155bc262e4..a251fc57993 100644 --- a/libs/core/mpi_base/src/mpi_environment.cpp +++ b/libs/core/mpi_base/src/mpi_environment.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2020 Google // Copyright (c) 2022 Patrick Diehl // Copyright (c) 2023 Hartmut Kaiser +// Copyright (c) 2024 Jiakun Yan // // SPDX-License-Identifier: BSL-1.0 // Distributed under the Boost Software License, Version 1.0. (See accompanying @@ -467,6 +468,97 @@ namespace hpx::util { report_error(sl, error_code); } + + // Acknowledgement: code adapted from github.com/jeffhammond/BigMPI + MPI_Datatype mpi_environment::type_contiguous(size_t nbytes) + { + size_t int_max = (std::numeric_limits::max)(); + + size_t c = nbytes / int_max; + size_t r = nbytes % int_max; + + HPX_ASSERT(c < int_max); + HPX_ASSERT(r < int_max); + + MPI_Datatype chunks; + MPI_Type_vector(c, int_max, int_max, MPI_BYTE, &chunks); + + MPI_Datatype remainder; + MPI_Type_contiguous(r, MPI_BYTE, &remainder); + + MPI_Aint remdisp = (MPI_Aint) c * int_max; + int blocklengths[2] = {1, 1}; + MPI_Aint displacements[2] = {0, remdisp}; + MPI_Datatype types[2] = {chunks, remainder}; + MPI_Datatype newtype; + MPI_Type_create_struct(2, blocklengths, displacements, types, &newtype); + + MPI_Type_free(&chunks); + MPI_Type_free(&remainder); + + return newtype; + } + + MPI_Request mpi_environment::isend( + void* address, size_t size, int rank, int tag) + { + MPI_Request request; + MPI_Datatype datatype; + int length; + if (size > static_cast((std::numeric_limits::max)())) + { + datatype = type_contiguous(size); + MPI_Type_commit(&datatype); + length = 1; + } + else + { + datatype = MPI_BYTE; + length = static_cast(size); + } + + { + scoped_lock l; + int const ret = MPI_Isend( + address, length, datatype, rank, tag, communicator(), &request); + check_mpi_error(l, HPX_CURRENT_SOURCE_LOCATION(), ret); + } + + if (datatype != MPI_BYTE) + MPI_Type_free(&datatype); + return request; + } + + MPI_Request mpi_environment::irecv( + void* address, size_t size, int rank, int tag) + { + MPI_Request request; + MPI_Datatype datatype; + int length; + if (size > static_cast((std::numeric_limits::max)())) + { + datatype = type_contiguous(size); + MPI_Type_commit(&datatype); + length = 1; + } + else + { + datatype = MPI_BYTE; + length = static_cast(size); + } + + { + scoped_lock l; + int const ret = MPI_Irecv( + address, length, datatype, rank, tag, communicator(), &request); + check_mpi_error(l, HPX_CURRENT_SOURCE_LOCATION(), ret); + } + + if (datatype != MPI_BYTE) + MPI_Type_free(&datatype); + + return request; + } } // namespace hpx::util #endif diff --git a/libs/full/parcelport_mpi/include/hpx/parcelport_mpi/receiver_connection.hpp b/libs/full/parcelport_mpi/include/hpx/parcelport_mpi/receiver_connection.hpp index bed4e218af3..eba1305b575 100644 --- a/libs/full/parcelport_mpi/include/hpx/parcelport_mpi/receiver_connection.hpp +++ b/libs/full/parcelport_mpi/include/hpx/parcelport_mpi/receiver_connection.hpp @@ -1,6 +1,6 @@ // Copyright (c) 2014-2015 Thomas Heller // Copyright (c) 2007-2024 Hartmut Kaiser -// Copyright (c) 2023 Jiakun Yan +// Copyright (c) 2023-2024 Jiakun Yan // // SPDX-License-Identifier: BSL-1.0 // Distributed under the Boost Software License, Version 1.0. (See accompanying @@ -163,14 +163,11 @@ namespace hpx::parcelset::policies::mpi { { util::mpi_environment::scoped_lock l; - int const ret = MPI_Irecv(buffer_.transmission_chunks_.data(), - static_cast(buffer_.transmission_chunks_.size() * - sizeof(buffer_type::transmission_chunk_type)), - MPI_BYTE, src_, tag_, util::mpi_environment::communicator(), - &request_); - util::mpi_environment::check_mpi_error( - l, HPX_CURRENT_SOURCE_LOCATION(), ret); - + request_ = util::mpi_environment::irecv( + buffer_.transmission_chunks_.data(), + buffer_.transmission_chunks_.size() * + sizeof(buffer_type::transmission_chunk_type), + src_, tag_); request_ptr_ = &request_; state_ = connection_state::rcvd_transmission_chunks; @@ -207,12 +204,8 @@ namespace hpx::parcelset::policies::mpi { ack_ = static_cast( connection_state::acked_transmission_chunks); - int const ret = - MPI_Isend(&ack_, sizeof(ack_), MPI_BYTE, src_, ack_tag(), - util::mpi_environment::communicator(), &request_); - util::mpi_environment::check_mpi_error( - l, HPX_CURRENT_SOURCE_LOCATION(), ret); - + request_ = util::mpi_environment::isend( + &ack_, sizeof(ack_), src_, ack_tag()); request_ptr_ = &request_; } @@ -241,14 +234,8 @@ namespace hpx::parcelset::policies::mpi { if (need_recv_data) { - util::mpi_environment::scoped_lock l; - - int const ret = MPI_Irecv(buffer_.data_.data(), - static_cast(buffer_.data_.size()), MPI_BYTE, src_, - tag_, util::mpi_environment::communicator(), &request_); - util::mpi_environment::check_mpi_error( - l, HPX_CURRENT_SOURCE_LOCATION(), ret); - + request_ = util::mpi_environment::irecv( + buffer_.data_.data(), buffer_.data_.size(), src_, tag_); request_ptr_ = &request_; state_ = connection_state::rcvd_data; @@ -276,15 +263,8 @@ namespace hpx::parcelset::policies::mpi { HPX_ASSERT(request_ptr_ == nullptr); { - util::mpi_environment::scoped_lock l; - - ack_ = static_cast(connection_state::acked_data); - int const ret = - MPI_Isend(&ack_, sizeof(ack_), MPI_BYTE, src_, ack_tag(), - util::mpi_environment::communicator(), &request_); - util::mpi_environment::check_mpi_error( - l, HPX_CURRENT_SOURCE_LOCATION(), ret); - + request_ = util::mpi_environment::isend( + &ack_, sizeof(ack_), src_, ack_tag()); request_ptr_ = &request_; } @@ -372,17 +352,9 @@ namespace hpx::parcelset::policies::mpi { "zero-copy chunk buffers should have been initialized " "during de-serialization"); - { - util::mpi_environment::scoped_lock l; - - int const ret = MPI_Irecv(c.data(), - static_cast(chunk_size), MPI_BYTE, src_, tag_, - util::mpi_environment::communicator(), &request_); - util::mpi_environment::check_mpi_error( - l, HPX_CURRENT_SOURCE_LOCATION(), ret); - - request_ptr_ = &request_; - } + request_ = util::mpi_environment::irecv( + c.data(), c.size(), src_, tag_); + request_ptr_ = &request_; } HPX_ASSERT_MSG( zero_copy_chunks_idx_ == buffer_.num_chunks_.first, @@ -412,14 +384,8 @@ namespace hpx::parcelset::policies::mpi { c.data(), chunk_size); { - util::mpi_environment::scoped_lock l; - - int const ret = MPI_Irecv(c.data(), - static_cast(c.size()), MPI_BYTE, src_, tag_, - util::mpi_environment::communicator(), &request_); - util::mpi_environment::check_mpi_error( - l, HPX_CURRENT_SOURCE_LOCATION(), ret); - + request_ = util::mpi_environment::irecv( + c.data(), c.size(), src_, tag_); request_ptr_ = &request_; } } diff --git a/libs/full/parcelport_mpi/include/hpx/parcelport_mpi/sender_connection.hpp b/libs/full/parcelport_mpi/include/hpx/parcelport_mpi/sender_connection.hpp index 7b80fec4390..9eb969d3708 100644 --- a/libs/full/parcelport_mpi/include/hpx/parcelport_mpi/sender_connection.hpp +++ b/libs/full/parcelport_mpi/include/hpx/parcelport_mpi/sender_connection.hpp @@ -1,6 +1,6 @@ // Copyright (c) 2007-2024 Hartmut Kaiser // Copyright (c) 2014-2015 Thomas Heller -// Copyright (c) 2023 Jiakun Yan +// Copyright (c) 2023-2024 Jiakun Yan // // SPDX-License-Identifier: BSL-1.0 // Distributed under the Boost Software License, Version 1.0. (See accompanying @@ -177,17 +177,9 @@ namespace hpx::parcelset::policies::mpi { HPX_ASSERT(state_ == connection_state::initialized); HPX_ASSERT(request_ptr_ == nullptr); - { - util::mpi_environment::scoped_lock l; - - int const ret = MPI_Isend(header_buffer.data(), - static_cast(header_buffer.size()), MPI_BYTE, dst_, 0, - util::mpi_environment::communicator(), &request_); - util::mpi_environment::check_mpi_error( - l, HPX_CURRENT_SOURCE_LOCATION(), ret); - - request_ptr_ = &request_; - } + request_ = util::mpi_environment::isend( + header_buffer.data(), header_buffer.size(), dst_, 0); + request_ptr_ = &request_; state_ = connection_state::sent_header; return send_transmission_chunks(); @@ -206,16 +198,10 @@ namespace hpx::parcelset::policies::mpi { auto const& chunks = buffer_.transmission_chunks_; if (!chunks.empty() && !header_.piggy_back_tchunk()) { - util::mpi_environment::scoped_lock l; - - int const ret = MPI_Isend(chunks.data(), - static_cast(chunks.size() * - sizeof(parcel_buffer_type::transmission_chunk_type)), - MPI_BYTE, dst_, tag_, util::mpi_environment::communicator(), - &request_); - util::mpi_environment::check_mpi_error( - l, HPX_CURRENT_SOURCE_LOCATION(), ret); - + request_ = util::mpi_environment::isend( + const_cast( + reinterpret_cast(chunks.data())), + chunks.size(), dst_, tag_); request_ptr_ = &request_; state_ = connection_state::sent_transmission_chunks; @@ -250,14 +236,8 @@ namespace hpx::parcelset::policies::mpi { HPX_ASSERT(request_ptr_ == nullptr); { - util::mpi_environment::scoped_lock l; - - int const ret = - MPI_Irecv(&ack_, sizeof(ack_), MPI_BYTE, dst_, ack_tag(), - util::mpi_environment::communicator(), &request_); - util::mpi_environment::check_mpi_error( - l, HPX_CURRENT_SOURCE_LOCATION(), ret); - + request_ = util::mpi_environment::irecv( + &ack_, sizeof(ack_), dst_, ack_tag()); request_ptr_ = &request_; } @@ -283,12 +263,8 @@ namespace hpx::parcelset::policies::mpi { { util::mpi_environment::scoped_lock l; - int const ret = MPI_Isend(buffer_.data_.data(), - static_cast(buffer_.data_.size()), MPI_BYTE, dst_, - tag_, util::mpi_environment::communicator(), &request_); - util::mpi_environment::check_mpi_error( - l, HPX_CURRENT_SOURCE_LOCATION(), ret); - + request_ = util::mpi_environment::isend( + buffer_.data_.data(), buffer_.data_.size(), dst_, tag_); request_ptr_ = &request_; state_ = connection_state::sent_data; @@ -321,14 +297,8 @@ namespace hpx::parcelset::policies::mpi { HPX_ASSERT(request_ptr_ == nullptr); { - util::mpi_environment::scoped_lock l; - - int const ret = - MPI_Irecv(&ack_, sizeof(ack_), MPI_BYTE, dst_, ack_tag(), - util::mpi_environment::communicator(), &request_); - util::mpi_environment::check_mpi_error( - l, HPX_CURRENT_SOURCE_LOCATION(), ret); - + request_ = util::mpi_environment::irecv( + &ack_, sizeof(ack_), dst_, ack_tag()); request_ptr_ = &request_; } @@ -352,15 +322,8 @@ namespace hpx::parcelset::policies::mpi { return false; } HPX_ASSERT(request_ptr_ == nullptr); - - util::mpi_environment::scoped_lock l; - - int const ret = MPI_Isend(c.data_.cpos_, - static_cast(c.size_), MPI_BYTE, dst_, tag_, - util::mpi_environment::communicator(), &request_); - util::mpi_environment::check_mpi_error( - l, HPX_CURRENT_SOURCE_LOCATION(), ret); - + request_ = util::mpi_environment::isend( + const_cast(c.data()), c.size(), dst_, tag_); request_ptr_ = &request_; }