Skip to content

Commit

Permalink
review updates:
Browse files Browse the repository at this point in the history
- documentation
- renaming

Co-authored-by: Pratik Nayak <[email protected]>
Co-authored-by: Tobias Ribizel <[email protected]>
  • Loading branch information
3 people committed Apr 30, 2024
1 parent e4a1b6d commit 17c52a5
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 52 deletions.
11 changes: 6 additions & 5 deletions core/distributed/index_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,24 +77,25 @@ size_type index_map<LocalIndexType, GlobalIndexType>::get_global_size() const

template <typename LocalIndexType, typename GlobalIndexType>
array<LocalIndexType> index_map<LocalIndexType, GlobalIndexType>::get_local(
const array<GlobalIndexType>& global_ids, index_space is) const
const array<GlobalIndexType>& global_ids, index_space index_space_v) const
{
array<LocalIndexType> local_ids(exec_);

exec_->run(index_map_kernels::make_get_local(
partition_.get(), remote_target_ids_, remote_global_idxs_, rank_,
global_ids, is, local_ids));
global_ids, index_space_v, local_ids));

return local_ids;
}


template <typename LocalIndexType, typename GlobalIndexType>
index_map<LocalIndexType, GlobalIndexType>::index_map(
std::shared_ptr<const Executor> exec, std::shared_ptr<const part_type> part,
comm_index_type rank, const array<GlobalIndexType>& recv_connections)
std::shared_ptr<const Executor> exec,
std::shared_ptr<const partition_type> partition, comm_index_type rank,
const array<GlobalIndexType>& recv_connections)
: exec_(std::move(exec)),
partition_(std::move(part)),
partition_(std::move(partition)),
rank_(rank),
remote_target_ids_(exec_),
remote_local_idxs_(exec_),
Expand Down
12 changes: 7 additions & 5 deletions core/distributed/index_map_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
//
// SPDX-License-Identifier: BSD-3-Clause

#ifndef INDEX_MAP_KERNELS_HPP
#define INDEX_MAP_KERNELS_HPP
#ifndef GKO_CORE_DISTRIBUTED_INDEX_MAP_KERNELS_HPP_
#define GKO_CORE_DISTRIBUTED_INDEX_MAP_KERNELS_HPP_


#include <ginkgo/core/distributed/index_map.hpp>


#include <ginkgo/core/base/array.hpp>
#include <ginkgo/core/base/collection.hpp>
#include <ginkgo/core/distributed/index_map.hpp>
#include <ginkgo/core/distributed/partition.hpp>


Expand All @@ -34,7 +36,7 @@ namespace kernels {
std::shared_ptr<const DefaultExecutor> exec, \
const experimental::distributed::Partition<_ltype, _gtype>* partition, \
const array<experimental::distributed::comm_index_type>& \
remote_targed_ids, \
remote_target_ids, \
const collection::array<_gtype>& remote_global_idxs, \
experimental::distributed::comm_index_type rank, \
const array<_gtype>& global_ids, \
Expand All @@ -59,4 +61,4 @@ GKO_DECLARE_FOR_ALL_EXECUTOR_NAMESPACES(index_map,
} // namespace kernels
} // namespace gko

#endif // INDEX_MAP_KERNELS_HPP
#endif // GKO_CORE_DISTRIBUTED_INDEX_MAP_KERNELS_HPP_
18 changes: 10 additions & 8 deletions include/ginkgo/core/distributed/index_map.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,17 @@ enum class index_space {
*/
template <typename LocalIndexType, typename GlobalIndexType = int64>
struct index_map {
using part_type = Partition<LocalIndexType, GlobalIndexType>;
using partition_type = Partition<LocalIndexType, GlobalIndexType>;

/**
* \brief Maps global indices to local indices
*
* \param global_ids the global indices to map
* \param is the index space in which the returned local indices are defined
* \param global_ids the global indices to map
* \param index_space_v the index space in which the returned local indices
* are defined
*
* \return the mapped local indices. Any global index that is not in the
* specified index space is mapped to invalid_index.
* \return the mapped local indices. Any global index that is not in the
* specified index space is mapped to invalid_index.
*/
array<LocalIndexType> get_local(
const array<GlobalIndexType>& global_ids,
Expand Down Expand Up @@ -98,13 +99,14 @@ struct index_map {
* filtered out.
*
* \param exec the executor
* \param part the partition of the global index set
* \param partition the partition of the global index set
* \param rank the id of the global index space subset
* \param recv_connections the global indices that are not owned by this
* rank, but accessed by it
*/
index_map(std::shared_ptr<const Executor> exec,
std::shared_ptr<const part_type> part, comm_index_type rank,
std::shared_ptr<const partition_type> partition,
comm_index_type rank,
const array<GlobalIndexType>& recv_connections);

/**
Expand Down Expand Up @@ -157,7 +159,7 @@ struct index_map {

private:
std::shared_ptr<const Executor> exec_;
std::shared_ptr<const part_type> partition_;
std::shared_ptr<const partition_type> partition_;
comm_index_type rank_;

array<comm_index_type> remote_target_ids_;
Expand Down
35 changes: 6 additions & 29 deletions reference/distributed/index_map_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,35 +28,12 @@ void build_mapping(
collection::array<GlobalIndexType>& remote_global_idxs)
{
using experimental::distributed::comm_index_type;
using partition_type =
experimental::distributed::Partition<LocalIndexType, GlobalIndexType>;
auto part_ids = part->get_part_ids();

std::vector<GlobalIndexType> unique_indices(recv_connections.get_size());
std::copy_n(recv_connections.get_const_data(), recv_connections.get_size(),
unique_indices.begin());

auto find_range = [](GlobalIndexType idx, const partition_type* partition,
size_type hint) {
auto range_bounds = partition->get_range_bounds();
auto num_ranges = partition->get_num_ranges();
if (range_bounds[hint] <= idx && idx < range_bounds[hint + 1]) {
return hint;
} else {
auto it = std::upper_bound(range_bounds + 1,
range_bounds + num_ranges + 1, idx);
return static_cast<size_type>(std::distance(range_bounds + 1, it));
}
};

auto map_to_local = [](GlobalIndexType idx, const partition_type* partition,
size_type range_id) {
auto range_bounds = partition->get_range_bounds();
auto range_starting_indices = partition->get_range_starting_indices();
return static_cast<LocalIndexType>(idx - range_bounds[range_id]) +
range_starting_indices[range_id];
};

auto find_part = [&](GlobalIndexType idx) {
auto range_id = find_range(idx, part, 0);
return part_ids[range_id];
Expand Down Expand Up @@ -139,7 +116,7 @@ void get_local(
std::shared_ptr<const DefaultExecutor> exec,
const experimental::distributed::Partition<LocalIndexType, GlobalIndexType>*
partition,
const array<experimental::distributed::comm_index_type>& remote_targed_ids,
const array<experimental::distributed::comm_index_type>& remote_target_ids,
const collection::array<GlobalIndexType>& remote_global_idxs,
experimental::distributed::comm_index_type rank,
const array<GlobalIndexType>& global_ids,
Expand Down Expand Up @@ -171,13 +148,13 @@ void get_local(
// the global indexing. So find the part-id that corresponds
// to the global index first
auto set_id = std::distance(
remote_targed_ids.get_const_data(),
std::lower_bound(remote_targed_ids.get_const_data(),
remote_targed_ids.get_const_data() +
remote_targed_ids.get_size(),
remote_target_ids.get_const_data(),
std::lower_bound(remote_target_ids.get_const_data(),
remote_target_ids.get_const_data() +
remote_target_ids.get_size(),
part_id));

if (set_id == remote_targed_ids.get_size()) {
if (set_id == remote_target_ids.get_size()) {
return invalid_index<LocalIndexType>();
}

Expand Down
5 changes: 0 additions & 5 deletions reference/test/distributed/index_map_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
#include "core/distributed/index_map_kernels.hpp"
#include "core/test/utils.hpp"

namespace {


using comm_index_type = gko::experimental::distributed::comm_index_type;

Expand Down Expand Up @@ -164,6 +162,3 @@ TEST_F(IndexMap, CanGetLocalWithCombinedISWithInvalid)
gko::array<local_index_type> expected(ref, {2, 3, 0, 1, 2, 4, -1, 1});
GKO_ASSERT_ARRAY_EQ(local_ids, expected);
}


} // namespace

0 comments on commit 17c52a5

Please sign in to comment.