diff --git a/xla/backends/cpu/collectives/in_process_communicator.cc b/xla/backends/cpu/collectives/in_process_communicator.cc index b5ab1396e38477..2d4dc88f9ef27a 100644 --- a/xla/backends/cpu/collectives/in_process_communicator.cc +++ b/xla/backends/cpu/collectives/in_process_communicator.cc @@ -344,7 +344,7 @@ absl::Status InProcessCommunicator::AllReduce(se::DeviceMemoryBase send_buffer, std::string name = absl::StrCat("all reduce ", key.ToString()); AllReduceParticipant partiticipant{rank_, send_buffer, recv_buffer}; - return RendezvousSingle( + return Rendezvous( name, key, partiticipant, key.num_local_participants, std::bind(AllReduceOp, dtype, count, reduction_kind, std::placeholders::_1)); @@ -362,7 +362,7 @@ absl::Status InProcessCommunicator::CollectivePermute( recv_buffer}; size_t num_bytes = count * primitive_util::ByteWidth(dtype); - return RendezvousSingle( + return Rendezvous( name, key, partiticipant, key.num_local_participants, std::bind(CollectivePermuteOp, num_bytes, std::placeholders::_1)); } @@ -380,7 +380,7 @@ absl::Status InProcessCommunicator::AllToAll( {recv_buffers.begin(), recv_buffers.end()}}; size_t num_bytes = count * primitive_util::ByteWidth(dtype); - return RendezvousSingle( + return Rendezvous( name, key, partiticipant, key.num_local_participants, std::bind(AllToAllOp, num_bytes, std::placeholders::_1)); } @@ -396,7 +396,7 @@ absl::Status InProcessCommunicator::AllGather(se::DeviceMemoryBase send_buffer, AllGatherParticipant partiticipant{rank_, send_buffer, recv_buffer}; size_t num_bytes = count * primitive_util::ByteWidth(dtype); - return RendezvousSingle( + return Rendezvous( name, key, partiticipant, key.num_local_participants, std::bind(AllGatherOp, num_bytes, std::placeholders::_1)); } @@ -411,7 +411,7 @@ absl::Status InProcessCommunicator::ReduceScatter( std::string name = absl::StrCat("reduce scatter ", key.ToString()); ReduceScatterParticipant partiticipant{rank_, send_buffer, recv_buffer}; - return RendezvousSingle( + return Rendezvous( name, key, partiticipant, key.num_local_participants, std::bind(ReduceScatterOp, dtype, count, reduction_kind, std::placeholders::_1)); diff --git a/xla/backends/gpu/collectives/gpu_cliques.cc b/xla/backends/gpu/collectives/gpu_cliques.cc index a06cb864d556fd..77398835588d82 100644 --- a/xla/backends/gpu/collectives/gpu_cliques.cc +++ b/xla/backends/gpu/collectives/gpu_cliques.cc @@ -292,7 +292,7 @@ InitializeGpuClique(GpuCollectives* collectives, se::StreamExecutor* device, // processes are not able to synchronize device activity. RendezvousArg rendezvous_arg = std::make_pair(device_rank, synchronized); - return RendezvousSingle>( + return Rendezvous>( initialization_rendezvous_name, rendezvous_key, rendezvous_arg, num_local_participants, initialize, WarnStuckTimeout(), TerminateTimeout()); @@ -431,7 +431,7 @@ InitializeGpuClique(GpuCollectives* collectives, se::StreamExecutor* device, rank.value(), clique_key.ToString(), run_id.ToInt(), parent_clique_key.ToString()); - return RendezvousSingle>( + return Rendezvous>( initialization_rendezvous_name, rendezvous_key, rank_pair, num_local_participants, split, WarnStuckTimeout(), TerminateTimeout()); } @@ -466,7 +466,7 @@ absl::StatusOr> AcquireGpuClique( TF_ASSIGN_OR_RETURN( std::shared_ptr clique, - RendezvousSingle>( + Rendezvous>( rendezvous_name, rendezvous_key, num_local_participants, [&] { tsl::profiler::TraceMe trace("LockGpuClique"); diff --git a/xla/debug_options_flags.cc b/xla/debug_options_flags.cc index b78b87c8a15dff..b4ea4e0a4a04b8 100644 --- a/xla/debug_options_flags.cc +++ b/xla/debug_options_flags.cc @@ -2124,13 +2124,13 @@ void MakeDebugOptionsFlags(std::vector* flag_list, int32_setter_for( &DebugOptions::set_xla_gpu_executable_warn_stuck_timeout_seconds), debug_options->xla_gpu_executable_warn_stuck_timeout_seconds(), - "Set timeout for RendezvousSingle stuck warning")); + "Set timeout for Rendezvous stuck warning")); flag_list->push_back(tsl::Flag( "xla_gpu_executable_terminate_timeout", int32_setter_for( &DebugOptions::set_xla_gpu_executable_terminate_timeout_seconds), debug_options->xla_gpu_executable_terminate_timeout_seconds(), - "Set timeout for RendezvousSingle termination")); + "Set timeout for Rendezvous termination")); flag_list->push_back(tsl::Flag( "xla_gpu_experimental_disable_binary_libraries", bool_setter_for( diff --git a/xla/service/gpu/gpu_executable.cc b/xla/service/gpu/gpu_executable.cc index 9e0a5bebcc4a2d..f656e9691a0d3a 100644 --- a/xla/service/gpu/gpu_executable.cc +++ b/xla/service/gpu/gpu_executable.cc @@ -626,7 +626,7 @@ absl::Status RendezvousAfterInitialization( run_options->device_ordinal(), run_options->run_options().run_id().ToInt()); - RendezvousSingle( + Rendezvous( rendezvous_name, rendezvous_key, num_local_participants, absl::Seconds( debug_options diff --git a/xla/service/gpu/runtime/nccl_collective_thunk.cc b/xla/service/gpu/runtime/nccl_collective_thunk.cc index 1c839dcb18c9bf..47211e6f437c91 100644 --- a/xla/service/gpu/runtime/nccl_collective_thunk.cc +++ b/xla/service/gpu/runtime/nccl_collective_thunk.cc @@ -479,10 +479,10 @@ absl::Status NcclCollectiveThunk::ExecuteOnStream(const ExecuteParams& params) { "first call to collective operation %d; run_id=%d", config().op_id, params.collective_params->run_id.ToInt()); - RendezvousSingle(first_call_rendezvous_flag_, rendezvous_name, - rendezvous_key, num_local_participants, - /*warn_stuck_timeout=*/absl::Seconds(20), - /*terminate_timeout=*/absl::Seconds(40)); + Rendezvous(first_call_rendezvous_flag_, rendezvous_name, rendezvous_key, + num_local_participants, + /*warn_stuck_timeout=*/absl::Seconds(20), + /*terminate_timeout=*/absl::Seconds(40)); } return absl::OkStatus(); diff --git a/xla/service/gpu/runtime/nccl_collective_thunk.h b/xla/service/gpu/runtime/nccl_collective_thunk.h index acdb18d68a3fc3..5b5ba1fcf26995 100644 --- a/xla/service/gpu/runtime/nccl_collective_thunk.h +++ b/xla/service/gpu/runtime/nccl_collective_thunk.h @@ -210,7 +210,7 @@ class NcclCollectiveThunk : public Thunk { // // TODO(ezhulenev): Try to move this flag to NCCL clique as we need to make // sure that all NCCL resources are allocated just once. - RendezvousSingleFlag first_call_rendezvous_flag_; + RendezvousFlag first_call_rendezvous_flag_; }; //===----------------------------------------------------------------------===// diff --git a/xla/service/rendezvous.cc b/xla/service/rendezvous.cc index a22c5537a4d451..e9c88dbca2e5a6 100644 --- a/xla/service/rendezvous.cc +++ b/xla/service/rendezvous.cc @@ -137,13 +137,12 @@ inline constexpr int32_t kPending = 0; inline constexpr int32_t kCompleted = std::numeric_limits::max(); } // namespace -RendezvousSingleFlag::RendezvousSingleFlag() : state_(kPending) {} +RendezvousFlag::RendezvousFlag() : state_(kPending) {} -RendezvousSingleFlag::InFlightRendezvous::InFlightRendezvous( - RendezvousSingleFlag* flag) +RendezvousFlag::InFlightRendezvous::InFlightRendezvous(RendezvousFlag* flag) : flag_(flag) {} -RendezvousSingleFlag::InFlightRendezvous::~InFlightRendezvous() { +RendezvousFlag::InFlightRendezvous::~InFlightRendezvous() { if (flag_ == nullptr) return; // Reload state and use CAS to decide if we are the one who @@ -162,11 +161,11 @@ RendezvousSingleFlag::InFlightRendezvous::~InFlightRendezvous() { } } -RendezvousSingleFlag::InFlightRendezvous::operator bool() const { +RendezvousFlag::InFlightRendezvous::operator bool() const { return flag_ != nullptr; } -RendezvousSingleFlag::InFlightRendezvous RendezvousSingleFlag::TryJoin() { +RendezvousFlag::InFlightRendezvous RendezvousFlag::TryJoin() { // If `state_` is `kCompleted` it means that we have at least one completed // rendezvous for this flag and can skip it. if (state_.load() == kCompleted) return InFlightRendezvous(nullptr); @@ -184,8 +183,6 @@ RendezvousSingleFlag::InFlightRendezvous RendezvousSingleFlag::TryJoin() { return InFlightRendezvous(this); } -bool RendezvousSingleFlag::IsCompleted() const { - return state_.load() == kCompleted; -} +bool RendezvousFlag::IsCompleted() const { return state_.load() == kCompleted; } } // namespace xla diff --git a/xla/service/rendezvous.h b/xla/service/rendezvous.h index b19776601e8943..ffd4c431003726 100644 --- a/xla/service/rendezvous.h +++ b/xla/service/rendezvous.h @@ -85,14 +85,14 @@ using RendezvousResultType = typename RendezvousResult::Type; // all threads receive the result. Rendezvous must have a human readable name to // make easy to debug stuck and timed out attempts. template -RendezvousResultType RendezvousSingle( +RendezvousResultType Rendezvous( absl::string_view name, const K& key, const V& value, size_t num_threads, Fn fn, absl::Duration warn_stuck_timeout = absl::InfiniteDuration(), absl::Duration terminate_timeout = absl::InfiniteDuration()); // A rendezvous for a group of threads that do not have any value arguments. template -RendezvousResultType RendezvousSingle( +RendezvousResultType Rendezvous( absl::string_view name, const K& key, size_t num_threads, Fn fn, absl::Duration warn_stuck_timeout = absl::InfiniteDuration(), absl::Duration terminate_timeout = absl::InfiniteDuration()); @@ -100,14 +100,13 @@ RendezvousResultType RendezvousSingle( // A rendezvous for a group of threads that do not have any computation to run // and simply acts as a barrier for a group of thread. template -void RendezvousSingle( - absl::string_view name, const K& key, size_t num_threads, - absl::Duration warn_stuck_timeout = absl::InfiniteDuration(), - absl::Duration terminate_timeout = absl::InfiniteDuration()); +void Rendezvous(absl::string_view name, const K& key, size_t num_threads, + absl::Duration warn_stuck_timeout = absl::InfiniteDuration(), + absl::Duration terminate_timeout = absl::InfiniteDuration()); -// An `std::once_flag`-like primitive for executing RendezvousSingle operations. +// An `std::once_flag`-like primitive for executing Rendezvous operations. // -// RendezvousSingleFlag guarantees that all or none participants in a rendezvous +// RendezvousFlag guarantees that all or none participants in a rendezvous // join the rendezvous process and once rendezvous is completed flag marked as // `completed` and all further rendezvous using this flag will be skipped. It // has a weaker than exactly-once guarantee and multiple racing rendezvous can @@ -119,17 +118,17 @@ void RendezvousSingle( // and prefer simpler implementation with weaker guarantees. // // See: https://en.cppreference.com/w/cpp/thread/once_flag -class RendezvousSingleFlag { +class RendezvousFlag { public: - RendezvousSingleFlag(); + RendezvousFlag(); - RendezvousSingleFlag(const RendezvousSingleFlag&) = delete; - RendezvousSingleFlag& operator=(const RendezvousSingleFlag&) = delete; + RendezvousFlag(const RendezvousFlag&) = delete; + RendezvousFlag& operator=(const RendezvousFlag&) = delete; // RAII wrapper to exit from in-flight rendezvous when destructed. class InFlightRendezvous { public: - explicit InFlightRendezvous(RendezvousSingleFlag* flag); + explicit InFlightRendezvous(RendezvousFlag* flag); ~InFlightRendezvous(); InFlightRendezvous(const InFlightRendezvous&) = delete; @@ -138,7 +137,7 @@ class RendezvousSingleFlag { operator bool() const; // NOLINT private: - RendezvousSingleFlag* flag_; + RendezvousFlag* flag_; }; // Returns InFlightRendezvous convertible to `true` if the caller should join @@ -159,8 +158,8 @@ class RendezvousSingleFlag { // rendezvous. If rendezvous will not be executed it will return empty shared // pointer result. template -RendezvousResultType RendezvousSingle( - RendezvousSingleFlag& flag, absl::string_view name, const K& key, +RendezvousResultType Rendezvous( + RendezvousFlag& flag, absl::string_view name, const K& key, size_t num_threads, Fn fn, absl::Duration warn_stuck_timeout = absl::InfiniteDuration(), absl::Duration terminate_timeout = absl::InfiniteDuration()); @@ -169,11 +168,10 @@ RendezvousResultType RendezvousSingle( // not in `completed` state and will switch it to `completed` after finishing a // rendezvous. template -void RendezvousSingle( - RendezvousSingleFlag& flag, absl::string_view name, const K& key, - size_t num_threads, - absl::Duration warn_stuck_timeout = absl::InfiniteDuration(), - absl::Duration terminate_timeout = absl::InfiniteDuration()); +void Rendezvous(RendezvousFlag& flag, absl::string_view name, const K& key, + size_t num_threads, + absl::Duration warn_stuck_timeout = absl::InfiniteDuration(), + absl::Duration terminate_timeout = absl::InfiniteDuration()); //===----------------------------------------------------------------------===// // Internal implementation details. @@ -291,11 +289,10 @@ void AwaitAndLogIfStuck(RendezvousStateSynchronization& state, int32_t id, //===----------------------------------------------------------------------===// template -RendezvousResultType RendezvousSingle(absl::string_view name, const K& key, - const V& value, size_t num_threads, - Fn fn, - absl::Duration warn_stuck_timeout, - absl::Duration terminate_timeout) { +RendezvousResultType Rendezvous(absl::string_view name, const K& key, + const V& value, size_t num_threads, Fn fn, + absl::Duration warn_stuck_timeout, + absl::Duration terminate_timeout) { // Check that `fn` is callable with a span of values and returns `R`. static_assert(std::is_invocable_r_v>, "invalid rendezvous function signature"); @@ -319,7 +316,7 @@ RendezvousResultType RendezvousSingle(absl::string_view name, const K& key, tsl::profiler::TraceMe trace([&] { return tsl::profiler::TraceMeEncode( - "RendezvousSingle", + "Rendezvous", {{"num_threads", num_threads}, {"name", name}, {"id", id}}); }); @@ -355,46 +352,44 @@ RendezvousResultType RendezvousSingle(absl::string_view name, const K& key, } template -RendezvousResultType RendezvousSingle(absl::string_view name, const K& key, - size_t num_threads, Fn fn, - absl::Duration warn_stuck_timeout, - absl::Duration terminate_timeout) { - return RendezvousSingle( +RendezvousResultType Rendezvous(absl::string_view name, const K& key, + size_t num_threads, Fn fn, + absl::Duration warn_stuck_timeout, + absl::Duration terminate_timeout) { + return Rendezvous( name, key, std::nullopt, num_threads, [fn](auto) { return fn(); }, warn_stuck_timeout, terminate_timeout); } template -void RendezvousSingle(absl::string_view name, const K& key, size_t num_threads, - absl::Duration warn_stuck_timeout, - absl::Duration terminate_timeout) { - RendezvousSingle( +void Rendezvous(absl::string_view name, const K& key, size_t num_threads, + absl::Duration warn_stuck_timeout, + absl::Duration terminate_timeout) { + Rendezvous( name, key, std::nullopt, num_threads, [](auto) { return std::nullopt; }, warn_stuck_timeout, terminate_timeout); } template -RendezvousResultType RendezvousSingle(RendezvousSingleFlag& flag, - absl::string_view name, const K& key, - size_t num_threads, Fn fn, - absl::Duration warn_stuck_timeout, - absl::Duration terminate_timeout) { +RendezvousResultType Rendezvous(RendezvousFlag& flag, absl::string_view name, + const K& key, size_t num_threads, Fn fn, + absl::Duration warn_stuck_timeout, + absl::Duration terminate_timeout) { if (auto in_flight_rendezvous = flag.TryJoin()) { - return RendezvousSingle(name, key, num_threads, std::move(fn), - warn_stuck_timeout, terminate_timeout); + return Rendezvous(name, key, num_threads, std::move(fn), + warn_stuck_timeout, terminate_timeout); } else { return RendezvousResult::Empty(); } } template -void RendezvousSingle(RendezvousSingleFlag& flag, absl::string_view name, - const K& key, size_t num_threads, - absl::Duration warn_stuck_timeout, - absl::Duration terminate_timeout) { +void Rendezvous(RendezvousFlag& flag, absl::string_view name, const K& key, + size_t num_threads, absl::Duration warn_stuck_timeout, + absl::Duration terminate_timeout) { if (auto in_flight_rendezvous = flag.TryJoin()) { - RendezvousSingle(name, key, num_threads, warn_stuck_timeout, - terminate_timeout); + Rendezvous(name, key, num_threads, warn_stuck_timeout, + terminate_timeout); } } diff --git a/xla/service/rendezvous_test.cc b/xla/service/rendezvous_test.cc index 867d24971f078b..c47550a63de17a 100644 --- a/xla/service/rendezvous_test.cc +++ b/xla/service/rendezvous_test.cc @@ -41,8 +41,7 @@ tsl::thread::ThreadPool CreateThreadPool(int32_t size) { } TEST(RendezvousTest, OneParticipant) { - auto result = - RendezvousSingle("rendezvous_test", 0, 1, [] { return 42; }); + auto result = Rendezvous("rendezvous_test", 0, 1, [] { return 42; }); ASSERT_EQ(*result, 42); } @@ -53,7 +52,7 @@ TEST(RendezvousTest, TwoParticipants) { auto task = [&](int32_t id) { return [&, id] { results[id] = - RendezvousSingle("rendezvous_test", 0, 2, [] { return 42; }); + Rendezvous("rendezvous_test", 0, 2, [] { return 42; }); counter.DecrementCount(); }; }; @@ -81,7 +80,7 @@ TEST(RendezvousTest, TwoParticipantsWithValues) { auto task = [&](int32_t id) { return [&, id] { results[id] = - RendezvousSingle("rendezvous_test", 0, id, 2, accumulate); + Rendezvous("rendezvous_test", 0, id, 2, accumulate); counter.DecrementCount(); }; }; @@ -103,7 +102,7 @@ TEST(RendezvousTest, RepeatRendezvous) { absl::BlockingCounter counter(2); auto task = [&] { - RendezvousSingle("rendezvous_test", i, 2, [] { return 42; }); + Rendezvous("rendezvous_test", i, 2, [] { return 42; }); counter.DecrementCount(); }; @@ -119,8 +118,8 @@ TEST(RendezvousTest, ReturningStatusOr) { auto task = [&](int32_t id) { return [&, id] { - results[id] = RendezvousSingle>( - "rendezvous_test", 0, 2, [] { return 42; }); + results[id] = Rendezvous>("rendezvous_test", 0, 2, + [] { return 42; }); counter.DecrementCount(); }; }; @@ -135,8 +134,8 @@ TEST(RendezvousTest, ReturningStatusOr) { ASSERT_EQ(**results[1], 42); } -TEST(RendezvousTest, RendezvousSingleFlag) { - RendezvousSingleFlag flag; +TEST(RendezvousTest, RendezvousFlag) { + RendezvousFlag flag; auto thread_pool = CreateThreadPool(2); int32_t num_executed = 0; @@ -146,7 +145,7 @@ TEST(RendezvousTest, RendezvousSingleFlag) { auto task = [&](absl::BlockingCounter& counter) { return [&] { - RendezvousSingle( + Rendezvous( flag, "rendezvous_test", 0, 2, [&] { return ++num_executed; }, Timeout(), Terminate()); counter.DecrementCount(); @@ -169,8 +168,8 @@ TEST(RendezvousTest, RendezvousSingleFlag) { ASSERT_EQ(num_executed, 1); } -TEST(RendezvousTest, RendezvousSingleFlagRace) { - RendezvousSingleFlag flag; +TEST(RendezvousTest, RendezvousFlagRace) { + RendezvousFlag flag; static constexpr int32_t kNumRendezvous = 16; static constexpr int32_t kNumThreads = 8; @@ -179,8 +178,8 @@ TEST(RendezvousTest, RendezvousSingleFlagRace) { auto task = [&](int32_t key) { return [&, key] { - RendezvousSingle(flag, "key: " + std::to_string(key), key, kNumThreads, - Timeout(), Terminate()); + Rendezvous(flag, "key: " + std::to_string(key), key, kNumThreads, + Timeout(), Terminate()); }; }; @@ -191,8 +190,8 @@ TEST(RendezvousTest, RendezvousSingleFlagRace) { } } -TEST(RendezvousTest, RendezvousSingleFlagRaceWithBarriers) { - RendezvousSingleFlag flag; +TEST(RendezvousTest, RendezvousFlagRaceWithBarriers) { + RendezvousFlag flag; static constexpr int32_t kNumRendezvous = 16; static constexpr int32_t kNumThreads = 8; @@ -209,8 +208,8 @@ TEST(RendezvousTest, RendezvousSingleFlagRaceWithBarriers) { return [&, key] { participants_ready.DecrementCount(); participants_notification.WaitForNotification(); - RendezvousSingle(flag, "key: " + std::to_string(key), key, kNumThreads, - Timeout(), Terminate()); + Rendezvous(flag, "key: " + std::to_string(key), key, kNumThreads, + Timeout(), Terminate()); participants_done.DecrementCount(); }; }; @@ -238,8 +237,8 @@ static void BM_Rendezvous(benchmark::State& state) { absl::BlockingCounter counter(num_threads); for (int64_t i = 0; i < num_threads; ++i) { thread_pool.Schedule([&] { - RendezvousSingle("rendezvous_test", 0, num_threads, - [] { return 42; }); + Rendezvous("rendezvous_test", 0, num_threads, + [] { return 42; }); counter.DecrementCount(); }); } @@ -256,8 +255,8 @@ static void BM_RendezvousWithValues(benchmark::State& state) { for (int64_t i = 0; i < num_threads; ++i) { thread_pool.Schedule([&] { int32_t value = i; - RendezvousSingle("rendezvous_test", 0, value, num_threads, - [](auto) { return 42; }); + Rendezvous("rendezvous_test", 0, value, num_threads, + [](auto) { return 42; }); counter.DecrementCount(); }); } diff --git a/xla/xla.proto b/xla/xla.proto index 448cc49c9d9e7f..b13f2ca9b54621 100644 --- a/xla/xla.proto +++ b/xla/xla.proto @@ -1052,7 +1052,7 @@ message DebugOptions { AUTOTUNE_CACHE_MODE_READ = 2; } - // Timeouts for RendezvousSingle stuck warning and termination. + // Timeouts for Rendezvous stuck warning and termination. int32 xla_gpu_executable_warn_stuck_timeout_seconds = 327; int32 xla_gpu_executable_terminate_timeout_seconds = 328;