diff --git a/core/distributed/index_map.cpp b/core/distributed/index_map.cpp index 7e497f15b34..c09bbdd13b2 100644 --- a/core/distributed/index_map.cpp +++ b/core/distributed/index_map.cpp @@ -77,13 +77,13 @@ size_type index_map::get_global_size() const template array index_map::get_local( - const array& global_ids, index_space is) const + const array& global_ids, index_space index_space_v) const { array local_ids(exec_); exec_->run(index_map_kernels::make_get_local( partition_.get(), remote_target_ids_, remote_global_idxs_, rank_, - global_ids, is, local_ids)); + global_ids, index_space_v, local_ids)); return local_ids; } @@ -91,10 +91,11 @@ array index_map::get_local( template index_map::index_map( - std::shared_ptr exec, std::shared_ptr part, - comm_index_type rank, const array& recv_connections) + std::shared_ptr exec, + std::shared_ptr partition, comm_index_type rank, + const array& recv_connections) : exec_(std::move(exec)), - partition_(std::move(part)), + partition_(std::move(partition)), rank_(rank), remote_target_ids_(exec_), remote_local_idxs_(exec_), diff --git a/core/distributed/index_map_kernels.hpp b/core/distributed/index_map_kernels.hpp index bf93acd963a..ff9e5fc800c 100644 --- a/core/distributed/index_map_kernels.hpp +++ b/core/distributed/index_map_kernels.hpp @@ -2,13 +2,15 @@ // // SPDX-License-Identifier: BSD-3-Clause -#ifndef INDEX_MAP_KERNELS_HPP -#define INDEX_MAP_KERNELS_HPP +#ifndef GKO_CORE_DISTRIBUTED_INDEX_MAP_KERNELS_HPP_ +#define GKO_CORE_DISTRIBUTED_INDEX_MAP_KERNELS_HPP_ + + +#include #include #include -#include #include @@ -34,7 +36,7 @@ namespace kernels { std::shared_ptr exec, \ const experimental::distributed::Partition<_ltype, _gtype>* partition, \ const array& \ - remote_targed_ids, \ + remote_target_ids, \ const collection::array<_gtype>& remote_global_idxs, \ experimental::distributed::comm_index_type rank, \ const array<_gtype>& global_ids, \ @@ -59,4 +61,4 @@ GKO_DECLARE_FOR_ALL_EXECUTOR_NAMESPACES(index_map, } // namespace kernels } // namespace gko -#endif // INDEX_MAP_KERNELS_HPP +#endif // GKO_CORE_DISTRIBUTED_INDEX_MAP_KERNELS_HPP_ diff --git a/include/ginkgo/core/distributed/index_map.hpp b/include/ginkgo/core/distributed/index_map.hpp index 99eba9d459c..222141749e0 100644 --- a/include/ginkgo/core/distributed/index_map.hpp +++ b/include/ginkgo/core/distributed/index_map.hpp @@ -61,16 +61,17 @@ enum class index_space { */ template struct index_map { - using part_type = Partition; + using partition_type = Partition; /** * \brief Maps global indices to local indices * - * \param global_ids the global indices to map - * \param is the index space in which the returned local indices are defined + * \param global_ids the global indices to map + * \param index_space_v the index space in which the returned local indices + * are defined * - * \return the mapped local indices. Any global index that is not in the - * specified index space is mapped to invalid_index. + * \return the mapped local indices. Any global index that is not in the + * specified index space is mapped to invalid_index. */ array get_local( const array& global_ids, @@ -98,13 +99,14 @@ struct index_map { * filtered out. * * \param exec the executor - * \param part the partition of the global index set + * \param partition the partition of the global index set * \param rank the id of the global index space subset * \param recv_connections the global indices that are not owned by this * rank, but accessed by it */ index_map(std::shared_ptr exec, - std::shared_ptr part, comm_index_type rank, + std::shared_ptr partition, + comm_index_type rank, const array& recv_connections); /** @@ -157,7 +159,7 @@ struct index_map { private: std::shared_ptr exec_; - std::shared_ptr partition_; + std::shared_ptr partition_; comm_index_type rank_; array remote_target_ids_; diff --git a/reference/distributed/index_map_kernels.cpp b/reference/distributed/index_map_kernels.cpp index f1c9b8c5f7d..b0ff35b5eb9 100644 --- a/reference/distributed/index_map_kernels.cpp +++ b/reference/distributed/index_map_kernels.cpp @@ -28,35 +28,12 @@ void build_mapping( collection::array& remote_global_idxs) { using experimental::distributed::comm_index_type; - using partition_type = - experimental::distributed::Partition; auto part_ids = part->get_part_ids(); std::vector unique_indices(recv_connections.get_size()); std::copy_n(recv_connections.get_const_data(), recv_connections.get_size(), unique_indices.begin()); - auto find_range = [](GlobalIndexType idx, const partition_type* partition, - size_type hint) { - auto range_bounds = partition->get_range_bounds(); - auto num_ranges = partition->get_num_ranges(); - if (range_bounds[hint] <= idx && idx < range_bounds[hint + 1]) { - return hint; - } else { - auto it = std::upper_bound(range_bounds + 1, - range_bounds + num_ranges + 1, idx); - return static_cast(std::distance(range_bounds + 1, it)); - } - }; - - auto map_to_local = [](GlobalIndexType idx, const partition_type* partition, - size_type range_id) { - auto range_bounds = partition->get_range_bounds(); - auto range_starting_indices = partition->get_range_starting_indices(); - return static_cast(idx - range_bounds[range_id]) + - range_starting_indices[range_id]; - }; - auto find_part = [&](GlobalIndexType idx) { auto range_id = find_range(idx, part, 0); return part_ids[range_id]; @@ -139,7 +116,7 @@ void get_local( std::shared_ptr exec, const experimental::distributed::Partition* partition, - const array& remote_targed_ids, + const array& remote_target_ids, const collection::array& remote_global_idxs, experimental::distributed::comm_index_type rank, const array& global_ids, @@ -171,13 +148,13 @@ void get_local( // the global indexing. So find the part-id that corresponds // to the global index first auto set_id = std::distance( - remote_targed_ids.get_const_data(), - std::lower_bound(remote_targed_ids.get_const_data(), - remote_targed_ids.get_const_data() + - remote_targed_ids.get_size(), + remote_target_ids.get_const_data(), + std::lower_bound(remote_target_ids.get_const_data(), + remote_target_ids.get_const_data() + + remote_target_ids.get_size(), part_id)); - if (set_id == remote_targed_ids.get_size()) { + if (set_id == remote_target_ids.get_size()) { return invalid_index(); } diff --git a/reference/test/distributed/index_map_kernels.cpp b/reference/test/distributed/index_map_kernels.cpp index c22389692ce..8b94860ab0f 100644 --- a/reference/test/distributed/index_map_kernels.cpp +++ b/reference/test/distributed/index_map_kernels.cpp @@ -20,8 +20,6 @@ #include "core/distributed/index_map_kernels.hpp" #include "core/test/utils.hpp" -namespace { - using comm_index_type = gko::experimental::distributed::comm_index_type; @@ -164,6 +162,3 @@ TEST_F(IndexMap, CanGetLocalWithCombinedISWithInvalid) gko::array expected(ref, {2, 3, 0, 1, 2, 4, -1, 1}); GKO_ASSERT_ARRAY_EQ(local_ids, expected); } - - -} // namespace