diff --git a/core/distributed/matrix.cpp b/core/distributed/matrix.cpp
index d169d7bbdac..f5f6d32b8ea 100644
--- a/core/distributed/matrix.cpp
+++ b/core/distributed/matrix.cpp
@@ -6,6 +6,7 @@
 #include <...>
+#include <ginkgo/core/distributed/neighborhood_communicator.hpp>
 #include <...>
 #include <...>
@@ -45,14 +46,11 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::Matrix(
     : EnableDistributedLinOp<
           Matrix<ValueType, LocalIndexType, GlobalIndexType>>{exec},
       DistributedBase{comm},
-      send_offsets_(comm.size() + 1),
-      send_sizes_(comm.size()),
-      recv_offsets_(comm.size() + 1),
-      recv_sizes_(comm.size()),
-      gather_idxs_{exec},
       one_scalar_{},
       local_mtx_{local_matrix_template->clone(exec)},
-      non_local_mtx_{non_local_matrix_template->clone(exec)}
+      non_local_mtx_{non_local_matrix_template->clone(exec)},
+      row_gatherer_{RowGatherer<LocalIndexType>::create(exec, comm)},
+      imap_{exec}
 {
     GKO_ASSERT(
         (dynamic_cast<ReadableFromMatrixData<ValueType, LocalIndexType>*>(
@@ -105,11 +103,7 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::convert_to(
                result->get_communicator().size());
     result->local_mtx_->copy_from(this->local_mtx_);
     result->non_local_mtx_->copy_from(this->non_local_mtx_);
-    result->gather_idxs_ = this->gather_idxs_;
-    result->send_offsets_ = this->send_offsets_;
-    result->recv_offsets_ = this->recv_offsets_;
-    result->recv_sizes_ = this->recv_sizes_;
-    result->send_sizes_ = this->send_sizes_;
+    result->row_gatherer_->copy_from(this->row_gatherer_);
     result->set_size(this->get_size());
 }

@@ -123,11 +117,7 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::move_to(
                result->get_communicator().size());
     result->local_mtx_->move_from(this->local_mtx_);
     result->non_local_mtx_->move_from(this->non_local_mtx_);
-    result->gather_idxs_ = std::move(this->gather_idxs_);
-    result->send_offsets_ = std::move(this->send_offsets_);
-    result->recv_offsets_ = std::move(this->recv_offsets_);
-    result->recv_sizes_ = std::move(this->recv_sizes_);
-    result->send_sizes_ = std::move(this->send_sizes_);
+    result->row_gatherer_->move_from(this->row_gatherer_);
     result->set_size(this->get_size());
     this->set_size({});
 }
@@ -151,7 +141,6 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::read_distributed(
     auto local_part = comm.rank();

     // set up LinOp sizes
-    auto num_parts = static_cast<size_type>(row_partition->get_num_parts());
     auto global_num_rows = row_partition->get_size();
     auto global_num_cols = col_partition->get_size();
     dim<2> global_dim{global_num_rows, global_num_cols};
@@ -175,11 +164,11 @@
         local_row_idxs, local_col_idxs, local_values, non_local_row_idxs,
         global_non_local_col_idxs, non_local_values));

-    auto imap = index_map<LocalIndexType, GlobalIndexType>(
+    imap_ = index_map<LocalIndexType, GlobalIndexType>(
         exec, col_partition, comm.rank(), global_non_local_col_idxs);

     auto non_local_col_idxs =
-        imap.get_local(global_non_local_col_idxs, index_space::non_local);
+        imap_.get_local(global_non_local_col_idxs, index_space::non_local);

     // read the local matrix data
     const auto num_local_rows =
@@ -192,7 +181,7 @@
     device_matrix_data<value_type, local_index_type> non_local_data{
         exec,
         dim<2>{num_local_rows,
-               imap.get_remote_global_idxs().get_flat().get_size()},
+               imap_.get_remote_global_idxs().get_flat().get_size()},
         std::move(non_local_row_idxs), std::move(non_local_col_idxs),
         std::move(non_local_values)};
     as<ReadableFromMatrixData<value_type, local_index_type>>(this->local_mtx_)
@@ -200,40 +189,11 @@
     as<ReadableFromMatrixData<value_type, local_index_type>>(this->non_local_mtx_)
         ->read(std::move(non_local_data));

-    // exchange step 1: determine recv_sizes, send_sizes, send_offsets
-    auto host_recv_targets =
-        make_temporary_clone(exec->get_master(), &imap.get_remote_target_ids());
-    std::fill(recv_sizes_.begin(), recv_sizes_.end(), 0);
-    for (size_type i = 0; i < host_recv_targets->get_size(); ++i) {
-        recv_sizes_[host_recv_targets->get_const_data()[i]] =
-            imap.get_remote_global_idxs()[i].get_size();
-    }
-    std::partial_sum(recv_sizes_.begin(), recv_sizes_.end(),
-                     recv_offsets_.begin() + 1);
-    comm.all_to_all(exec, recv_sizes_.data(), 1, send_sizes_.data(), 1);
-    std::partial_sum(send_sizes_.begin(), send_sizes_.end(),
-                     send_offsets_.begin() + 1);
-    send_offsets_[0] = 0;
-    recv_offsets_[0] = 0;
-
-    // exchange step 2: exchange gather_idxs from receivers to senders
-    auto recv_gather_idxs = imap.get_remote_local_idxs().get_flat();
-    auto use_host_buffer = mpi::requires_host_buffer(exec, comm);
-    if (use_host_buffer) {
-        recv_gather_idxs.set_executor(exec->get_master());
-        gather_idxs_.clear();
-        gather_idxs_.set_executor(exec->get_master());
-    }
-    gather_idxs_.resize_and_reset(send_offsets_.back());
-    comm.all_to_all_v(use_host_buffer ? exec->get_master() : exec,
-                      recv_gather_idxs.get_const_data(), recv_sizes_.data(),
-                      recv_offsets_.data(), gather_idxs_.get_data(),
-                      send_sizes_.data(), send_offsets_.data());
-    if (use_host_buffer) {
-        gather_idxs_.set_executor(exec);
-    }
+    row_gatherer_ = RowGatherer<LocalIndexType>::create(
+        exec, std::make_shared<mpi::NeighborhoodCommunicator>(comm, imap_),
+        imap_);

-    return imap;
+    return imap_;
 }


@@ -278,50 +238,6 @@
 }


-template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
-mpi::request Matrix<ValueType, LocalIndexType, GlobalIndexType>::communicate(
-    const local_vector_type* local_b) const
-{
-    auto exec = this->get_executor();
-    const auto comm = this->get_communicator();
-    auto num_cols = local_b->get_size()[1];
-    auto send_size = send_offsets_.back();
-    auto recv_size = recv_offsets_.back();
-    auto send_dim = dim<2>{static_cast<size_type>(send_size), num_cols};
-    auto recv_dim = dim<2>{static_cast<size_type>(recv_size), num_cols};
-    recv_buffer_.init(exec, recv_dim);
-    send_buffer_.init(exec, send_dim);
-
-    local_b->row_gather(&gather_idxs_, send_buffer_.get());
-
-    auto use_host_buffer = mpi::requires_host_buffer(exec, comm);
-    if (use_host_buffer) {
-        host_recv_buffer_.init(exec->get_master(), recv_dim);
-        host_send_buffer_.init(exec->get_master(), send_dim);
-        host_send_buffer_->copy_from(send_buffer_.get());
-    }
-
-    mpi::contiguous_type type(num_cols, mpi::type_impl<ValueType>::get_type());
-    auto send_ptr = use_host_buffer ? host_send_buffer_->get_const_values()
-                                    : send_buffer_->get_const_values();
-    auto recv_ptr = use_host_buffer ? host_recv_buffer_->get_values()
-                                    : recv_buffer_->get_values();
-    exec->synchronize();
-#ifdef GINKGO_FORCE_SPMV_BLOCKING_COMM
-    comm.all_to_all_v(use_host_buffer ? exec->get_master() : exec, send_ptr,
-                      send_sizes_.data(), send_offsets_.data(), type.get(),
-                      recv_ptr, recv_sizes_.data(), recv_offsets_.data(),
-                      type.get());
-    return {};
-#else
-    return comm.i_all_to_all_v(
-        use_host_buffer ? exec->get_master() : exec, send_ptr,
-        send_sizes_.data(), send_offsets_.data(), type.get(), recv_ptr,
-        recv_sizes_.data(), recv_offsets_.data(), type.get());
-#endif
-}
-
-
 template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
 void Matrix<ValueType, LocalIndexType, GlobalIndexType>::apply_impl(
     const LinOp* b, LinOp* x) const
@@ -337,16 +253,16 @@
                 dense_x->get_local_values()),
             dense_x->get_local_vector()->get_stride());

+        auto exec = this->get_executor();
         auto comm = this->get_communicator();
-        auto req = this->communicate(dense_b->get_local_vector());
+        auto recv_dim =
+            dim<2>{imap_.get_non_local_size(), dense_b->get_size()[1]};
+        recv_buffer_.init(exec, recv_dim);
+        auto req =
+            this->row_gatherer_->apply_async(dense_b, recv_buffer_.get());
         local_mtx_->apply(dense_b->get_local_vector(), local_x);
         req.wait();

-        auto exec = this->get_executor();
-        auto use_host_buffer = mpi::requires_host_buffer(exec, comm);
-        if (use_host_buffer) {
-            recv_buffer_->copy_from(host_recv_buffer_.get());
-        }
         non_local_mtx_->apply(one_scalar_.get(), recv_buffer_.get(),
                               one_scalar_.get(), local_x);
     },
@@ -370,17 +286,17 @@
                 dense_x->get_local_values()),
             dense_x->get_local_vector()->get_stride());

+        auto exec = this->get_executor();
         auto comm = this->get_communicator();
-        auto req = this->communicate(dense_b->get_local_vector());
+        auto recv_dim =
+            dim<2>{imap_.get_non_local_size(), dense_b->get_size()[1]};
+        recv_buffer_.init(exec, recv_dim);
+        auto req =
+            this->row_gatherer_->apply_async(dense_b, recv_buffer_.get());
         local_mtx_->apply(local_alpha, dense_b->get_local_vector(),
                           local_beta, local_x);
         req.wait();

-        auto exec = this->get_executor();
-        auto use_host_buffer = mpi::requires_host_buffer(exec, comm);
-        if (use_host_buffer) {
-            recv_buffer_->copy_from(host_recv_buffer_.get());
-        }
         non_local_mtx_->apply(local_alpha, recv_buffer_.get(),
                               one_scalar_.get(), local_x);
     },
@@ -392,7 +308,10 @@ template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
 Matrix<ValueType, LocalIndexType, GlobalIndexType>::Matrix(const Matrix& other)
     : EnableDistributedLinOp<Matrix<ValueType, LocalIndexType,
                                     GlobalIndexType>>{other.get_executor()},
-      DistributedBase{other.get_communicator()}
+      DistributedBase{other.get_communicator()},
+      row_gatherer_{RowGatherer<LocalIndexType>::create(
+          other.get_executor(), other.get_communicator())},
+      imap_{other.get_executor()}
 {
     *this = other;
 }
@@ -403,7 +322,10 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::Matrix(
     Matrix&& other) noexcept
     : EnableDistributedLinOp<Matrix<ValueType, LocalIndexType,
                                     GlobalIndexType>>{other.get_executor()},
-      DistributedBase{other.get_communicator()}
+      DistributedBase{other.get_communicator()},
+      row_gatherer_{RowGatherer<LocalIndexType>::create(
+          other.get_executor(), other.get_communicator())},
+      imap_{other.get_executor()}
 {
     *this = std::move(other);
 }
@@ -420,11 +342,7 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::operator=(
         this->set_size(other.get_size());
         local_mtx_->copy_from(other.local_mtx_);
         non_local_mtx_->copy_from(other.non_local_mtx_);
-        gather_idxs_ = other.gather_idxs_;
-        send_offsets_ = other.send_offsets_;
-        recv_offsets_ = other.recv_offsets_;
-        send_sizes_ = other.send_sizes_;
-        recv_sizes_ = other.recv_sizes_;
+        row_gatherer_->copy_from(other.row_gatherer_);
         one_scalar_.init(this->get_executor(), dim<2>{1, 1});
         one_scalar_->fill(one<value_type>());
     }
@@ -443,11 +361,7 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::operator=(Matrix&& other)
         other.set_size({});
         local_mtx_->move_from(other.local_mtx_);
         non_local_mtx_->move_from(other.non_local_mtx_);
-        gather_idxs_ = std::move(other.gather_idxs_);
-        send_offsets_ = std::move(other.send_offsets_);
-        recv_offsets_ = std::move(other.recv_offsets_);
-        send_sizes_ = std::move(other.send_sizes_);
-        recv_sizes_ = std::move(other.recv_sizes_);
+        row_gatherer_->move_from(other.row_gatherer_);
         one_scalar_.init(this->get_executor(), dim<2>{1, 1});
         one_scalar_->fill(one<value_type>());
     }
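Taken together, the `.cpp` changes replace the hand-rolled `communicate()` (two-phase `all_to_all_v` setup plus explicit host staging) with a single `RowGatherer::apply_async` call. Condensed, the new `apply_impl` hot path reads as below; every identifier is taken from the diff above, only the comments are added:

```cpp
// inside the precision-dispatch lambda of apply_impl:
auto recv_dim = dim<2>{imap_.get_non_local_size(),  // rows needed from peers
                       dense_b->get_size()[1]};     // number of vector columns
recv_buffer_.init(exec, recv_dim);  // DenseCache: reallocates only on resize
// start the non-blocking gather of the remote entries of b ...
auto req = this->row_gatherer_->apply_async(dense_b, recv_buffer_.get());
// ... and overlap it with the purely local part of the SpMV
local_mtx_->apply(dense_b->get_local_vector(), local_x);
req.wait();  // the halo values are now in recv_buffer_
// accumulate the non-local contribution: local_x += A_non_local * halo
non_local_mtx_->apply(one_scalar_.get(), recv_buffer_.get(),
                      one_scalar_.get(), local_x);
```

Note that the `mpi::requires_host_buffer` branch disappears from the matrix entirely; buffer placement is presumably handled inside the `RowGatherer` now.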
diff --git a/include/ginkgo/core/distributed/matrix.hpp b/include/ginkgo/core/distributed/matrix.hpp
index a0476534b71..5b939cefee1 100644
--- a/include/ginkgo/core/distributed/matrix.hpp
+++ b/include/ginkgo/core/distributed/matrix.hpp
@@ -17,6 +17,7 @@
 #include <...>
 #include <...>
 #include <...>
+#include <ginkgo/core/distributed/row_gatherer.hpp>


 namespace gko {
@@ -358,6 +359,23 @@ class Matrix
         return non_local_mtx_;
     }

+    /**
+     * Returns the RowGatherer that communicates the vector entries shared
+     * with other processes during the apply.
+     */
+    std::shared_ptr<const RowGatherer<LocalIndexType>> get_row_gatherer()
+        const
+    {
+        return row_gatherer_;
+    }
+
+    /**
+     * Returns the index map that maps the global column indices of the
+     * non-local matrix to their local representation.
+     */
+    const index_map<LocalIndexType, GlobalIndexType>& get_index_map() const
+    {
+        return imap_;
+    }
+
     /**
      * Copy constructs a Matrix.
      *
@@ -530,31 +548,15 @@ class Matrix
             ptr_param<const LinOp> local_matrix_template,
             ptr_param<const LinOp> non_local_matrix_template);

-    /**
-     * Starts a non-blocking communication of the values of b that are shared
-     * with other processors.
-     *
-     * @param local_b The full local vector to be communicated. The subset of
-     *                shared values is automatically extracted.
-     * @return MPI request for the non-blocking communication.
-     */
-    mpi::request communicate(const local_vector_type* local_b) const;
-
     void apply_impl(const LinOp* b, LinOp* x) const override;

     void apply_impl(const LinOp* alpha, const LinOp* b, const LinOp* beta,
                     LinOp* x) const override;

 private:
-    std::vector<comm_index_type> send_offsets_;
-    std::vector<comm_index_type> send_sizes_;
-    std::vector<comm_index_type> recv_offsets_;
-    std::vector<comm_index_type> recv_sizes_;
-    array<LocalIndexType> gather_idxs_;
+    std::shared_ptr<RowGatherer<LocalIndexType>> row_gatherer_;
+    index_map<LocalIndexType, GlobalIndexType> imap_;
     gko::detail::DenseCache<ValueType> one_scalar_;
-    gko::detail::DenseCache<ValueType> host_send_buffer_;
-    gko::detail::DenseCache<ValueType> host_recv_buffer_;
-    gko::detail::DenseCache<ValueType> send_buffer_;
     gko::detail::DenseCache<ValueType> recv_buffer_;
     std::shared_ptr<LinOp> local_mtx_;
     std::shared_ptr<LinOp> non_local_mtx_;
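For reference, a usage sketch of the two new accessors. Everything around them (the uniform partition, the tridiagonal assembly, the vector setup) is an illustrative assumption, not code from this change, and it requires a build with `GINKGO_BUILD_MPI=ON`:

```cpp
#include <ginkgo/ginkgo.hpp>

#include <iostream>

int main(int argc, char* argv[])
{
    using ValueType = double;
    using LocalIndexType = gko::int32;
    using GlobalIndexType = gko::int64;
    namespace dist = gko::experimental::distributed;

    gko::experimental::mpi::environment env(argc, argv);
    gko::experimental::mpi::communicator comm(MPI_COMM_WORLD);
    auto exec = gko::ReferenceExecutor::create();

    // uniform 1D partition of a global tridiagonal test matrix
    constexpr gko::int64 size = 100;
    auto part = gko::share(
        dist::Partition<LocalIndexType, GlobalIndexType>::
            build_from_global_size_uniform(exec, comm.size(), size));
    gko::matrix_data<ValueType, GlobalIndexType> data{gko::dim<2>{size, size}};
    for (gko::int64 i = 0; i < size; ++i) {
        data.nonzeros.emplace_back(i, i, 2.0);
        if (i > 0) data.nonzeros.emplace_back(i, i - 1, -1.0);
        if (i < size - 1) data.nonzeros.emplace_back(i, i + 1, -1.0);
    }
    data.ensure_row_major_order();

    auto mtx =
        dist::Matrix<ValueType, LocalIndexType, GlobalIndexType>::create(
            exec, comm);
    mtx->read_distributed(data, part);

    // x = A * b; apply() now overlaps the halo exchange started by
    // row_gatherer_->apply_async with the local SpMV (see the diff above)
    const auto local_rows =
        static_cast<gko::size_type>(part->get_part_size(comm.rank()));
    auto b = dist::Vector<ValueType>::create(
        exec, comm, gko::dim<2>{size, 1}, gko::dim<2>{local_rows, 1});
    b->fill(gko::one<ValueType>());
    auto x = gko::clone(b);
    mtx->apply(b, x);

    // the new accessors expose the communication pattern apply() uses
    const auto& imap = mtx->get_index_map();
    auto gatherer = mtx->get_row_gatherer();  // shared with the matrix
    std::cout << "rank " << comm.rank() << " gathers "
              << imap.get_non_local_size() << " remote entries from "
              << imap.get_remote_target_ids().get_size() << " peer ranks"
              << " (gatherer " << gatherer.get() << ")\n";
}
```

Since `get_row_gatherer()` hands out a `shared_ptr`, downstream operators that need the same halo values (e.g. overlapping preconditioners) can hold on to the matrix's communication pattern instead of rebuilding it from the partition.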