Skip to content

Commit

Permalink
Fixing race in collective operations
Browse files Browse the repository at this point in the history
- adding test
  • Loading branch information
hkaiser committed Mar 26, 2024
1 parent 560777f commit f95805d
Show file tree
Hide file tree
Showing 5 changed files with 871 additions and 36 deletions.
34 changes: 22 additions & 12 deletions libs/core/futures/src/future_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,13 @@ namespace hpx::lcos::detail {
///////////////////////////////////////////////////////////////////////////
struct handle_continuation_recursion_count
{
handle_continuation_recursion_count() noexcept
: count_(threads::get_continuation_recursion_count())
handle_continuation_recursion_count() = default;

std::size_t increment()
{
++count_;
HPX_ASSERT(count_ == nullptr);
count_ = &threads::get_continuation_recursion_count();
return ++*count_;
}

handle_continuation_recursion_count(
Expand All @@ -59,10 +62,14 @@ namespace hpx::lcos::detail {

~handle_continuation_recursion_count()
{
--count_;
if (count_ != nullptr)
{
--*count_;
}
}

std::size_t& count_;
private:
std::size_t* count_ = nullptr;
};

///////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -252,17 +259,20 @@ namespace hpx::lcos::detail {
{
// We need to run the completion on a new thread if we are on a non HPX
// thread.
bool const is_hpx_thread = nullptr != hpx::threads::get_self_ptr();
bool recurse_asynchronously = false;

handle_continuation_recursion_count cnt;
if (is_hpx_thread)
{
#if defined(HPX_HAVE_THREADS_GET_STACK_POINTER)
bool recurse_asynchronously =
!this_thread::has_sufficient_stack_space();
recurse_asynchronously = !this_thread::has_sufficient_stack_space();
#else
handle_continuation_recursion_count const cnt;
bool recurse_asynchronously =
cnt.count_ > HPX_CONTINUATION_MAX_RECURSION_DEPTH ||
(hpx::threads::get_self_ptr() == nullptr);
recurse_asynchronously =
cnt.increment() > HPX_CONTINUATION_MAX_RECURSION_DEPTH;
#endif
}

bool const is_hpx_thread = nullptr != hpx::threads::get_self_ptr();
if (!is_hpx_thread || !recurse_asynchronously)
{
// directly execute continuation on this thread
Expand Down
2 changes: 1 addition & 1 deletion libs/core/lcos_local/include/hpx/lcos_local/and_gate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ namespace hpx::lcos::local {
// Note: This type is not thread-safe. It has to be protected from
// concurrent access by different threads by the code using instances
// of this type.
struct and_gate : public base_and_gate<hpx::no_mutex>
struct and_gate : base_and_gate<hpx::no_mutex>
{
private:
using base_type = base_and_gate<hpx::no_mutex>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,8 @@ namespace hpx::collectives::detail {
};

private:
std::size_t get_num_sites(std::size_t num_values) const noexcept
[[nodiscard]] constexpr std::size_t get_num_sites(
std::size_t num_values) const noexcept
{
return num_values == static_cast<std::size_t>(-1) ? num_sites_ :
num_values;
Expand Down Expand Up @@ -231,19 +232,22 @@ namespace hpx::collectives::detail {
std::size_t generation, std::size_t capacity, F&& f, Lock& l)
{
HPX_ASSERT_OWNS_LOCK(l);
auto sf = gate_.get_shared_future(l);

traits::detail::get_shared_state(sf)->reserve_callbacks(
get_num_sites(capacity));

auto fut = sf.then(hpx::launch::sync, HPX_FORWARD(F, f));

// Wait for the requested generation to be processed.
gate_.synchronize(generation == static_cast<std::size_t>(-1) ?
gate_.generation(l) :
generation,
l);

return fut;
// Get future from gate only after synchronization as otherwise we
// may get a future returned that does not belong to the requested
// generation.
auto sf = gate_.get_shared_future(l);

traits::detail::get_shared_state(sf)->reserve_callbacks(
get_num_sites(capacity));

return sf.then(hpx::launch::sync, HPX_FORWARD(F, f));
}

template <typename Lock>
Expand All @@ -262,9 +266,16 @@ namespace hpx::collectives::detail {
"collective operation {}, which {}, generation {}.",
basename_, operation, which, generation);
}
current_operation_ = operation;

if (generation == static_cast<std::size_t>(-1) ||
generation == gate_.generation(l))
{
current_operation_ = operation;
}

return true;
}

return false;
}

Expand All @@ -284,13 +295,19 @@ namespace hpx::collectives::detail {
// This callback will be invoked once for each participating
// site after all sites have checked in.

// On exit, keep track of number of invocations of this
// callback.
auto on_exit = hpx::experimental::scope_exit(
[this] { ++on_ready_count_; });

f.get(); // propagate any exceptions

// It does not matter whether the lock will be acquired here. It
// either is still being held by the surrounding logic or is
// re-acquired here (if `on_ready` happens to run on a new
// thread asynchronously).
std::unique_lock l(mtx_, std::try_to_lock);
[[maybe_unused]] util::ignore_while_checking il(&l);

// Verify that there is no overlap between different types of
// operations on the same communicator.
Expand All @@ -315,19 +332,14 @@ namespace hpx::collectives::detail {
l.unlock();
HPX_THROW_EXCEPTION(hpx::error::invalid_status,
"communicator::handle_data::on_ready",
"communictor {}: sequencing error, an excessive "
"communicator {}: sequencing error, an excessive "
"number of on_ready callbacks have been invoked before "
"the end of the collective operation {}, which {}, "
"generation {}. Expected count {}, received count {}.",
basename_, operation, which, generation,
on_ready_count_, num_sites_);
}

// On exit, keep track of number of invocations of this
// callback.
auto on_exit = hpx::experimental::scope_exit(
[this] { ++on_ready_count_; });

if constexpr (!std::is_same_v<std::nullptr_t,
std::decay_t<Finalizer>>)
{
Expand All @@ -338,8 +350,6 @@ namespace hpx::collectives::detail {
else
{
HPX_UNUSED(this);
HPX_UNUSED(which);
HPX_UNUSED(generation);
HPX_UNUSED(num_values);
HPX_UNUSED(finalizer);
}
Expand Down Expand Up @@ -373,7 +383,7 @@ namespace hpx::collectives::detail {

if constexpr (!std::is_same_v<std::nullptr_t, std::decay_t<Step>>)
{
// call provided step function for each invocation site
// Call provided step function for each invocation site.
HPX_FORWARD(Step, step)(access_data<Data>(num_values), which);
}

Expand All @@ -399,7 +409,7 @@ namespace hpx::collectives::detail {
"been invoked at the end of the collective {} "
"operation. Expected count {}, received count {}, "
"which {}, generation {}.",
*operation, on_ready_count_, num_sites_, which,
operation, on_ready_count_, num_sites_, which,
generation);
return;
}
Expand All @@ -416,7 +426,7 @@ namespace hpx::collectives::detail {
return f;
}

// protect against vector<bool> idiosyncrasies
// Protect against vector<bool> idiosyncrasies.
template <typename ValueType, typename Data>
static constexpr decltype(auto) handle_bool(Data&& data) noexcept
{
Expand All @@ -433,15 +443,15 @@ namespace hpx::collectives::detail {
template <typename Communicator, typename Operation>
friend struct hpx::traits::communication_operation;

mutex_type mtx_;
hpx::unique_any_nonser data_;
hpx::lcos::local::and_gate gate_;
std::size_t const num_sites_;
std::size_t on_ready_count_ = 0;
char const* current_operation_ = nullptr;
char const* basename_ = nullptr;
mutex_type mtx_;
bool needs_initialization_ = true;
bool data_available_ = false;
char const* basename_ = nullptr;
};
} // namespace hpx::collectives::detail

Expand Down
3 changes: 2 additions & 1 deletion libs/full/collectives/tests/unit/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2019-2023 Hartmut Kaiser
# Copyright (c) 2019-2024 Hartmut Kaiser
#
# SPDX-License-Identifier: BSL-1.0
# Distributed under the Boost Software License, Version 1.0. (See accompanying
Expand All @@ -23,6 +23,7 @@ if(HPX_WITH_NETWORKING)
set(tests
${tests}
broadcast_direct
concurrent_collectives
exclusive_scan_
gather
inclusive_scan_
Expand Down
Loading

0 comments on commit f95805d

Please sign in to comment.