From 52177fd98c1bdb490cfd005ef571c5aa9b466ce7 Mon Sep 17 00:00:00 2001 From: Trevor Steil Date: Thu, 25 Jul 2024 14:12:18 -0700 Subject: [PATCH] Updates array to use CRTP base classes and adds more tests --- include/ygm/container/array.hpp | 201 ++++++++++++--- include/ygm/container/detail/array.ipp | 230 ------------------ .../container/detail/block_partitioner.hpp | 86 +++++++ test/test_array.cpp | 104 +++++++- 4 files changed, 351 insertions(+), 270 deletions(-) delete mode 100644 include/ygm/container/detail/array.ipp create mode 100644 include/ygm/container/detail/block_partitioner.hpp diff --git a/include/ygm/container/array.hpp b/include/ygm/container/array.hpp index 7498bf27..e75067c4 100644 --- a/include/ygm/container/array.hpp +++ b/include/ygm/container/array.hpp @@ -7,37 +7,118 @@ #include #include +#include +#include +#include +#include +#include namespace ygm::container { template -class array { +class array + : public detail::base_async_insert_key_value, + std::tuple>, + public detail::base_misc, std::tuple>, + public detail::base_async_visit, + std::tuple>, + public detail::base_iteration, + std::tuple> { + friend class detail::base_misc, std::tuple>; + public: - using self_type = array; - using mapped_type = Value; - using key_type = Index; - using size_type = Index; - using ygm_for_all_types = std::tuple; - using container_type = ygm::container::array_tag; - using ptr_type = typename ygm::ygm_ptr; + using self_type = array; + using mapped_type = Value; + using key_type = Index; + using size_type = Index; + using for_all_args = std::tuple; + using container_type = ygm::container::array_tag; + using ptr_type = typename ygm::ygm_ptr; + + // Pull in async_visit for use within the array + using detail::base_async_visit, + std::tuple>::async_visit; array() = delete; - array(ygm::comm& comm, const size_type size); + array(ygm::comm& comm, const size_type size) + : m_global_size(size), + m_default_value{}, + m_comm(comm), + pthis(this), + partitioner(comm, size) { + pthis.check(m_comm); - array(ygm::comm& comm, const size_type size, - const mapped_type& default_value); + resize(size); + } - array(const self_type& rhs); + array(ygm::comm& comm, const size_type size, const mapped_type& default_value) + : m_global_size(size), + m_default_value(default_value), + m_comm(comm), + pthis(this), + partitioner(comm, size) { + pthis.check(m_comm); - ~array(); + resize(size); + } - void async_set(const key_type index, const mapped_type& value); + array(const self_type& rhs) + : m_global_size(rhs.m_global_size), + m_default_value(rhs.m_default_value), + m_local_vec(rhs.m_local_vec), + m_comm(rhs.m_comm), + partitioner(m_comm, m_global_size) { + pthis.check(m_comm); + resize(m_global_size); + } + + ~array() { m_comm.barrier(); } + + void local_insert(const key_type& key, const mapped_type& value) { + m_local_vec[partitioner.local_index(key)] = value; + } + + template + void local_visit(const key_type index, Function& fn, + const VisitorArgs&... args) { + ygm::detail::interrupt_mask mask(m_comm); + if constexpr (std::is_invocable() || + std::is_invocable()) { + ygm::meta::apply_optional( + fn, std::make_tuple(pthis), + std::forward_as_tuple( + index, m_local_vec[partitioner.local_index(index)], args...)); + } else { + static_assert(ygm::detail::always_false<>, + "remote array lambda must be " + "invocable with (const " + "key_type, mapped_type &, ...) or " + "(ptr_type, mapped_type &, ...) signatures"); + } + } + + void async_set(const key_type index, const mapped_type& value) { + detail::base_async_insert_key_value, + for_all_args>::async_insert(index, + value); + } template void async_binary_op_update_value(const key_type index, const mapped_type& value, - const BinaryOp& b); + const BinaryOp& b) { + ASSERT_RELEASE(index < m_global_size); + auto updater = [](const key_type i, mapped_type& v, + const mapped_type& new_value) { + BinaryOp* binary_op; + v = (*binary_op)(v, new_value); + }; + + async_visit(index, updater, value); + } void async_bit_and(const key_type index, const mapped_type& value) { async_binary_op_update_value(index, value, std::bit_and()); @@ -76,7 +157,15 @@ class array { } template - void async_unary_op_update_value(const key_type index, const UnaryOp& u); + void async_unary_op_update_value(const key_type index, const UnaryOp& u) { + ASSERT_RELEASE(index < m_global_size); + auto updater = [](const key_type i, mapped_type& v) { + UnaryOp* u; + v = (*u)(v); + }; + + async_visit(index, updater); + } void async_increment(const key_type index) { async_unary_op_update_value(index, @@ -88,42 +177,78 @@ class array { [](const mapped_type& v) { return v - 1; }); } - template - void async_visit(const key_type index, Visitor visitor, - const VisitorArgs&... args); + const mapped_type& default_value() const; - template - void for_all(Function fn); + void resize(const size_type size, const mapped_type& fill_value) { + m_comm.barrier(); - size_type size(); + // Copy current values into temporary vector for storing in + // ygm::container::array after resizing local array structures + std::vector> tmp_values; + tmp_values.reserve(local_size()); + local_for_all( + [&tmp_values](const key_type& index, const mapped_type& value) { + tmp_values.push_back(std::make_pair(index, value)); + }); - typename ygm::ygm_ptr get_ygm_ptr() const; + m_global_size = size; + partitioner = detail::block_partitioner(m_comm, size); - int owner(const key_type index) const; + m_local_vec.resize(partitioner.local_size(), fill_value); - bool is_mine(const key_type index) const; + m_default_value = fill_value; - ygm::comm& comm(); + // Repopulate array values + for (const auto& [index, value] : tmp_values) { + if (index < size) { + async_set(index, value); + } + } - const mapped_type& default_value() const; + m_comm.barrier(); + } - void resize(const size_type size, const mapped_type& fill_value); + size_t local_size() { return partitioner.local_size(); } - void resize(const size_type size); + size_t size() const { + m_comm.barrier(); + return m_global_size; + } - private: - template - void local_for_all(Function fn); + void resize(const size_type size) { resize(size, m_default_value); } - key_type local_index(key_type index); + void local_clear() { resize(0); } - key_type global_index(key_type index); + void local_swap(self_type& other) { + m_local_vec.swap(other.m_local_vec); + std::swap(m_global_size, other.m_global_size); + std::swap(m_default_value, other.m_default_value); + std::swap(partitioner, other.partitioner); + } + + template + void local_for_all(Function fn) { + if constexpr (std::is_invocable()) { + for (int i = 0; i < m_local_vec.size(); ++i) { + key_type g_index = partitioner.global_index(i); + fn(g_index, m_local_vec[i]); + } + } else if constexpr (std::is_invocable()) { + std::for_each(std::begin(m_local_vec), std::end(m_local_vec), fn); + } else { + static_assert(ygm::detail::always_false<>, + "local array lambda must be " + "invocable with (const " + "key_type, mapped_type &) or " + "(mapped_type &) signatures"); + } + } + + detail::block_partitioner partitioner; private: size_type m_global_size; - size_type m_small_block_size; - size_type m_large_block_size; - size_type m_local_start_index; mapped_type m_default_value; std::vector m_local_vec; ygm::comm& m_comm; @@ -131,5 +256,3 @@ class array { }; } // namespace ygm::container - -#include diff --git a/include/ygm/container/detail/array.ipp b/include/ygm/container/detail/array.ipp deleted file mode 100644 index bcb7cde9..00000000 --- a/include/ygm/container/detail/array.ipp +++ /dev/null @@ -1,230 +0,0 @@ -// Copyright 2019-2021 Lawrence Livermore National Security, LLC and other YGM -// Project Developers. See the top-level COPYRIGHT file for details. -// -// SPDX-License-Identifier: MIT - -#pragma once - -namespace ygm::container { - -template -array::array(ygm::comm &comm, const size_type size) - : m_global_size(size), m_default_value{}, m_comm(comm), pthis(this) { - pthis.check(m_comm); - - resize(size); -} - -template -array::array(ygm::comm &comm, const size_type size, - const mapped_type &dv) - : m_default_value(dv), m_comm(comm), pthis(this) { - pthis.check(m_comm); - - resize(size); -} - -template -array::array(const self_type &rhs) - : m_default_value(rhs.m_default_value), - m_comm(rhs.m_comm), - m_global_size(rhs.m_global_size), - m_small_block_size(rhs.m_small_block_size), - m_large_block_size(rhs.m_large_block_size), - m_local_start_index(rhs.m_local_start_index), - m_local_vec(rhs.m_local_vec), - pthis(this) { - pthis.check(m_comm); -} - -template -array::~array() { - m_comm.barrier(); -} - -template -void array::resize(const size_type size, - const mapped_type &fill_value) { - m_comm.barrier(); - - m_global_size = size; - m_small_block_size = size / m_comm.size(); - m_large_block_size = m_small_block_size + ((size / m_comm.size()) > 0); - - m_local_vec.resize( - m_small_block_size + (m_comm.rank() < (size % m_comm.size())), - fill_value); - - if (m_comm.rank() < (size % m_comm.size())) { - m_local_start_index = m_comm.rank() * m_large_block_size; - } else { - m_local_start_index = - (size % m_comm.size()) * m_large_block_size + - (m_comm.rank() - (size % m_comm.size())) * m_small_block_size; - } - - m_comm.barrier(); -} - -template -void array::resize(const size_type size) { - resize(size, m_default_value); -} - -template -void array::async_set(const key_type index, - const mapped_type &value) { - ASSERT_RELEASE(index < m_global_size); - auto putter = [](auto parray, const key_type i, const mapped_type &v) { - key_type l_index = parray->local_index(i); - ASSERT_RELEASE(l_index < parray->m_local_vec.size()); - parray->m_local_vec[l_index] = v; - }; - - int dest = owner(index); - m_comm.async(dest, putter, pthis, index, value); -} - -template -template -void array::async_binary_op_update_value(const key_type index, - const mapped_type &value, - const BinaryOp &b) { - ASSERT_RELEASE(index < m_global_size); - auto updater = [](const key_type i, mapped_type &v, - const mapped_type &new_value) { - BinaryOp *binary_op; - v = (*binary_op)(v, new_value); - }; - - async_visit(index, updater, value); -} -template -template -void array::async_unary_op_update_value(const key_type index, - const UnaryOp &u) { - ASSERT_RELEASE(index < m_global_size); - auto updater = [](const key_type i, mapped_type &v) { - UnaryOp *u; - v = (*u)(v); - }; - - async_visit(index, updater); -} - -template -template -void array::async_visit(const key_type index, Visitor visitor, - const VisitorArgs &...args) { - ASSERT_RELEASE(index < m_global_size); - int dest = owner(index); - auto visit_wrapper = [](auto parray, const key_type i, - const VisitorArgs &...args) { - key_type l_index = parray->local_index(i); - ASSERT_RELEASE(l_index < parray->m_local_vec.size()); - mapped_type &l_value = parray->m_local_vec[l_index]; - Visitor *vis = nullptr; - if constexpr (std::is_invocable() || - std::is_invocable()) { - ygm::meta::apply_optional(*vis, std::make_tuple(parray), - std::forward_as_tuple(i, l_value, args...)); - } else { - static_assert( - ygm::detail::always_false<>, - "remote array lambda signature must be invocable with (const " - "&key_type, mapped_type&, ...) or (ptr_type, const " - "&key_type, mapped_type&, ...) signatures"); - } - }; - - m_comm.async(dest, visit_wrapper, pthis, index, - std::forward(args)...); -} - -template -template -void array::for_all(Function fn) { - m_comm.barrier(); - local_for_all(fn); -} - -template -template -void array::local_for_all(Function fn) { - if constexpr (std::is_invocable()) { - for (int i = 0; i < m_local_vec.size(); ++i) { - key_type g_index = global_index(i); - fn(g_index, m_local_vec[i]); - } - } else if constexpr (std::is_invocable()) { - std::for_each(std::begin(m_local_vec), std::end(m_local_vec), fn); - } else { - static_assert(ygm::detail::always_false<>, - "local array lambda must be invocable with (const " - "key_type, mapped_type &) or (mapped_type &) signatures"); - } -} - -template -typename array::size_type array::size() { - return m_global_size; -} - -template -typename array::ptr_type array::get_ygm_ptr() - const { - return pthis; -} - -template -ygm::comm &array::comm() { - return m_comm; -} - -template -const typename array::mapped_type & -array::default_value() const { - return m_default_value; -} - -template -int array::owner(const key_type index) const { - int to_return; - // Owner depends on whether index is before switching to small blocks - if (index < (m_global_size % m_comm.size()) * m_large_block_size) { - to_return = index / m_large_block_size; - } else { - to_return = (m_global_size % m_comm.size()) + - (index - (m_global_size % m_comm.size()) * m_large_block_size) / - m_small_block_size; - } - ASSERT_RELEASE((to_return >= 0) && (to_return < m_comm.size())); - - return to_return; -} - -template -bool array::is_mine(const key_type index) const { - return owner(index) == m_comm.rank(); -} - -template -typename array::key_type array::local_index( - const key_type index) { - key_type to_return = index - m_local_start_index; - ASSERT_RELEASE((to_return >= 0) && (to_return <= m_small_block_size)); - return to_return; -} - -template -typename array::key_type array::global_index( - const key_type index) { - key_type to_return; - return m_local_start_index + index; -} - -}; // namespace ygm::container diff --git a/include/ygm/container/detail/block_partitioner.hpp b/include/ygm/container/detail/block_partitioner.hpp new file mode 100644 index 00000000..e581cdf6 --- /dev/null +++ b/include/ygm/container/detail/block_partitioner.hpp @@ -0,0 +1,86 @@ +// Copyright 2019-2021 Lawrence Livermore National Security, LLC and other YGM +// Project Developers. See the top-level COPYRIGHT file for details. +// +// SPDX-License-Identifier: MIT + +#pragma once +#include +#include + +#include + +namespace ygm::container::detail { + +template +struct block_partitioner { + using index_type = Index; + + block_partitioner(ygm::comm &comm, index_type partitioned_size) + : m_comm_size(comm.size()), + m_comm_rank(comm.rank()), + m_partitioned_size(partitioned_size) { + m_small_block_size = partitioned_size / m_comm_size; + m_large_block_size = + m_small_block_size + ((partitioned_size / m_comm_size) > 0); + + if (m_comm_rank < (partitioned_size % m_comm_size)) { + m_local_start_index = m_comm_rank * m_large_block_size; + } else { + m_local_start_index = + (partitioned_size % m_comm_size) * m_large_block_size + + (m_comm_rank - (partitioned_size % m_comm_size)) * m_small_block_size; + } + + m_local_size = + m_small_block_size + (m_comm_rank < (m_partitioned_size % m_comm_size)); + + if (m_comm_rank < (m_partitioned_size % m_comm_size)) { + m_local_start_index = m_comm_rank * m_large_block_size; + } else { + m_local_start_index = + (m_partitioned_size % m_comm_size) * m_large_block_size + + (m_comm_rank - (m_partitioned_size % m_comm_size)) * + m_small_block_size; + } + } + + int owner(const index_type &index) const { + int to_return; + // Owner depends on whether index is before switching to small blocks + if (index < (m_partitioned_size % m_comm_size) * m_large_block_size) { + to_return = index / m_large_block_size; + } else { + to_return = + (m_partitioned_size % m_comm_size) + + (index - (m_partitioned_size % m_comm_size) * m_large_block_size) / + m_small_block_size; + } + ASSERT_RELEASE((to_return >= 0) && (to_return < m_comm_size)); + + return to_return; + } + + index_type local_index(const index_type &global_index) { + index_type to_return = global_index - m_local_start_index; + ASSERT_RELEASE((to_return >= 0) && (to_return <= m_small_block_size)); + return to_return; + } + + index_type global_index(const index_type &local_index) { + index_type to_return; + return m_local_start_index + local_index; + } + + index_type local_size() { return m_local_size; } + + private: + int m_comm_size; + int m_comm_rank; + index_type m_partitioned_size; + index_type m_small_block_size; + index_type m_large_block_size; + index_type m_local_size; + index_type m_local_start_index; +}; + +} // namespace ygm::container::detail diff --git a/test/test_array.cpp b/test/test_array.cpp index ffb2069e..71f11414 100644 --- a/test/test_array.cpp +++ b/test/test_array.cpp @@ -23,7 +23,7 @@ int main(int argc, char **argv) { std::is_same_v); static_assert( std::is_same_v< - decltype(arr)::ygm_for_all_types, + decltype(arr)::for_all_args, std::tuple >); } @@ -227,5 +227,107 @@ int main(int argc, char **argv) { }); } + // Test resize + { + int large_size = 64; + int small_size = 32; + + ygm::container::array arr(world, large_size); + + if (world.rank0()) { + for (int i = 0; i < large_size; ++i) { + arr.async_set(i, 2 * i); + } + } + + world.barrier(); + + ASSERT_RELEASE(arr.size() == large_size); + arr.for_all([](const auto &index, const auto &value) { + ASSERT_RELEASE(value == 2 * index); + }); + + arr.resize(small_size); + + ASSERT_RELEASE(arr.size() == small_size); + arr.for_all([](const auto &index, const auto &value) { + ASSERT_RELEASE(value == 2 * index); + }); + + arr.resize(large_size); + + ASSERT_RELEASE(arr.size() == large_size); + arr.for_all([&small_size](const auto &index, const auto &value) { + if (index < small_size) { + ASSERT_RELEASE(value == 2 * index); + } + }); + } + + // Test clear + { + int initial_size = 64; + + ygm::container::array arr(world, initial_size); + + if (world.rank0()) { + for (int i = 0; i < initial_size; ++i) { + arr.async_set(i, 2 * i); + } + } + + world.barrier(); + + ASSERT_RELEASE(arr.size() == initial_size); + + arr.clear(); + + ASSERT_RELEASE(arr.size() == 0); + } + + // Test swap + { + int size1 = 32; + int size2 = 48; + + ygm::container::array arr1(world, size1); + ygm::container::array arr2(world, size2); + + if (world.rank0()) { + for (int i = 0; i < size1; ++i) { + arr1.async_set(i, 2 * i); + } + for (int i = 0; i < size2; ++i) { + arr2.async_set(i, 3 * i + 1); + } + } + + world.barrier(); + + ASSERT_RELEASE(arr1.size() == size1); + ASSERT_RELEASE(arr2.size() == size2); + + arr1.for_all([](const auto &index, const auto &value) { + ASSERT_RELEASE(value == 2 * index); + }); + + arr2.for_all([](const auto &index, const auto &value) { + ASSERT_RELEASE(value == 3 * index + 1); + }); + + arr1.swap(arr2); + + ASSERT_RELEASE(arr1.size() == size2); + ASSERT_RELEASE(arr2.size() == size1); + + arr1.for_all([](const auto &index, const auto &value) { + ASSERT_RELEASE(value == 3 * index + 1); + }); + + arr2.for_all([](const auto &index, const auto &value) { + ASSERT_RELEASE(value == 2 * index); + }); + } + return 0; }