Skip to content

Commit

Permalink
[xla] Rename RendezvousSingle to Rendezvous
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 713515129
  • Loading branch information
ezhulenev authored and Google-ML-Automation committed Jan 10, 2025
1 parent c99a4f9 commit 2ff173c
Show file tree
Hide file tree
Showing 10 changed files with 88 additions and 97 deletions.
10 changes: 5 additions & 5 deletions xla/backends/cpu/collectives/in_process_communicator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ absl::Status InProcessCommunicator::AllReduce(se::DeviceMemoryBase send_buffer,
std::string name = absl::StrCat("all reduce ", key.ToString());
AllReduceParticipant partiticipant{rank_, send_buffer, recv_buffer};

return RendezvousSingle<absl::Status>(
return Rendezvous<absl::Status>(
name, key, partiticipant, key.num_local_participants,
std::bind(AllReduceOp, dtype, count, reduction_kind,
std::placeholders::_1));
Expand All @@ -362,7 +362,7 @@ absl::Status InProcessCommunicator::CollectivePermute(
recv_buffer};

size_t num_bytes = count * primitive_util::ByteWidth(dtype);
return RendezvousSingle<absl::Status>(
return Rendezvous<absl::Status>(
name, key, partiticipant, key.num_local_participants,
std::bind(CollectivePermuteOp, num_bytes, std::placeholders::_1));
}
Expand All @@ -380,7 +380,7 @@ absl::Status InProcessCommunicator::AllToAll(
{recv_buffers.begin(), recv_buffers.end()}};

size_t num_bytes = count * primitive_util::ByteWidth(dtype);
return RendezvousSingle<absl::Status>(
return Rendezvous<absl::Status>(
name, key, partiticipant, key.num_local_participants,
std::bind(AllToAllOp, num_bytes, std::placeholders::_1));
}
Expand All @@ -396,7 +396,7 @@ absl::Status InProcessCommunicator::AllGather(se::DeviceMemoryBase send_buffer,
AllGatherParticipant partiticipant{rank_, send_buffer, recv_buffer};

size_t num_bytes = count * primitive_util::ByteWidth(dtype);
return RendezvousSingle<absl::Status>(
return Rendezvous<absl::Status>(
name, key, partiticipant, key.num_local_participants,
std::bind(AllGatherOp, num_bytes, std::placeholders::_1));
}
Expand All @@ -411,7 +411,7 @@ absl::Status InProcessCommunicator::ReduceScatter(
std::string name = absl::StrCat("reduce scatter ", key.ToString());
ReduceScatterParticipant partiticipant{rank_, send_buffer, recv_buffer};

return RendezvousSingle<absl::Status>(
return Rendezvous<absl::Status>(
name, key, partiticipant, key.num_local_participants,
std::bind(ReduceScatterOp, dtype, count, reduction_kind,
std::placeholders::_1));
Expand Down
6 changes: 3 additions & 3 deletions xla/backends/gpu/collectives/gpu_cliques.cc
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ InitializeGpuClique(GpuCollectives* collectives, se::StreamExecutor* device,
// processes are not able to synchronize device activity.
RendezvousArg rendezvous_arg = std::make_pair(device_rank, synchronized);

return RendezvousSingle<absl::StatusOr<LockableGpuClique::Lock>>(
return Rendezvous<absl::StatusOr<LockableGpuClique::Lock>>(
initialization_rendezvous_name, rendezvous_key, rendezvous_arg,
num_local_participants, initialize, WarnStuckTimeout(),
TerminateTimeout());
Expand Down Expand Up @@ -431,7 +431,7 @@ InitializeGpuClique(GpuCollectives* collectives, se::StreamExecutor* device,
rank.value(), clique_key.ToString(), run_id.ToInt(),
parent_clique_key.ToString());

return RendezvousSingle<absl::StatusOr<LockableGpuClique::Lock>>(
return Rendezvous<absl::StatusOr<LockableGpuClique::Lock>>(
initialization_rendezvous_name, rendezvous_key, rank_pair,
num_local_participants, split, WarnStuckTimeout(), TerminateTimeout());
}
Expand Down Expand Up @@ -466,7 +466,7 @@ absl::StatusOr<std::shared_ptr<LockableGpuClique::Lock>> AcquireGpuClique(

TF_ASSIGN_OR_RETURN(
std::shared_ptr<LockableGpuClique::Lock> clique,
RendezvousSingle<absl::StatusOr<LockableGpuClique::Lock>>(
Rendezvous<absl::StatusOr<LockableGpuClique::Lock>>(
rendezvous_name, rendezvous_key, num_local_participants,
[&] {
tsl::profiler::TraceMe trace("LockGpuClique");
Expand Down
4 changes: 2 additions & 2 deletions xla/debug_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2124,13 +2124,13 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
int32_setter_for(
&DebugOptions::set_xla_gpu_executable_warn_stuck_timeout_seconds),
debug_options->xla_gpu_executable_warn_stuck_timeout_seconds(),
"Set timeout for RendezvousSingle stuck warning"));
"Set timeout for Rendezvous stuck warning"));
flag_list->push_back(tsl::Flag(
"xla_gpu_executable_terminate_timeout",
int32_setter_for(
&DebugOptions::set_xla_gpu_executable_terminate_timeout_seconds),
debug_options->xla_gpu_executable_terminate_timeout_seconds(),
"Set timeout for RendezvousSingle termination"));
"Set timeout for Rendezvous termination"));
flag_list->push_back(tsl::Flag(
"xla_gpu_experimental_disable_binary_libraries",
bool_setter_for(
Expand Down
2 changes: 1 addition & 1 deletion xla/service/gpu/gpu_executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,7 @@ absl::Status RendezvousAfterInitialization(
run_options->device_ordinal(),
run_options->run_options().run_id().ToInt());

RendezvousSingle(
Rendezvous(
rendezvous_name, rendezvous_key, num_local_participants,
absl::Seconds(
debug_options
Expand Down
8 changes: 4 additions & 4 deletions xla/service/gpu/runtime/nccl_collective_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -479,10 +479,10 @@ absl::Status NcclCollectiveThunk::ExecuteOnStream(const ExecuteParams& params) {
"first call to collective operation %d; run_id=%d", config().op_id,
params.collective_params->run_id.ToInt());

RendezvousSingle(first_call_rendezvous_flag_, rendezvous_name,
rendezvous_key, num_local_participants,
/*warn_stuck_timeout=*/absl::Seconds(20),
/*terminate_timeout=*/absl::Seconds(40));
Rendezvous(first_call_rendezvous_flag_, rendezvous_name, rendezvous_key,
num_local_participants,
/*warn_stuck_timeout=*/absl::Seconds(20),
/*terminate_timeout=*/absl::Seconds(40));
}

return absl::OkStatus();
Expand Down
2 changes: 1 addition & 1 deletion xla/service/gpu/runtime/nccl_collective_thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ class NcclCollectiveThunk : public Thunk {
//
// TODO(ezhulenev): Try to move this flag to NCCL clique as we need to make
// sure that all NCCL resources are allocated just once.
RendezvousSingleFlag first_call_rendezvous_flag_;
RendezvousFlag first_call_rendezvous_flag_;
};

//===----------------------------------------------------------------------===//
Expand Down
15 changes: 6 additions & 9 deletions xla/service/rendezvous.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,13 +137,12 @@ inline constexpr int32_t kPending = 0;
inline constexpr int32_t kCompleted = std::numeric_limits<int32_t>::max();
} // namespace

RendezvousSingleFlag::RendezvousSingleFlag() : state_(kPending) {}
RendezvousFlag::RendezvousFlag() : state_(kPending) {}

RendezvousSingleFlag::InFlightRendezvous::InFlightRendezvous(
RendezvousSingleFlag* flag)
RendezvousFlag::InFlightRendezvous::InFlightRendezvous(RendezvousFlag* flag)
: flag_(flag) {}

RendezvousSingleFlag::InFlightRendezvous::~InFlightRendezvous() {
RendezvousFlag::InFlightRendezvous::~InFlightRendezvous() {
if (flag_ == nullptr) return;

// Reload state and use CAS to decide if we are the one who
Expand All @@ -162,11 +161,11 @@ RendezvousSingleFlag::InFlightRendezvous::~InFlightRendezvous() {
}
}

RendezvousSingleFlag::InFlightRendezvous::operator bool() const {
RendezvousFlag::InFlightRendezvous::operator bool() const {
return flag_ != nullptr;
}

RendezvousSingleFlag::InFlightRendezvous RendezvousSingleFlag::TryJoin() {
RendezvousFlag::InFlightRendezvous RendezvousFlag::TryJoin() {
// If `state_` is `kCompleted` it means that we have at least one completed
// rendezvous for this flag and can skip it.
if (state_.load() == kCompleted) return InFlightRendezvous(nullptr);
Expand All @@ -184,8 +183,6 @@ RendezvousSingleFlag::InFlightRendezvous RendezvousSingleFlag::TryJoin() {
return InFlightRendezvous(this);
}

bool RendezvousSingleFlag::IsCompleted() const {
return state_.load() == kCompleted;
}
bool RendezvousFlag::IsCompleted() const { return state_.load() == kCompleted; }

} // namespace xla
93 changes: 44 additions & 49 deletions xla/service/rendezvous.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,29 +85,28 @@ using RendezvousResultType = typename RendezvousResult<R>::Type;
// all threads receive the result. Rendezvous must have a human readable name to
// make it easy to debug stuck and timed out attempts.
template <typename R, typename K, typename V, typename Fn>
RendezvousResultType<R> RendezvousSingle(
RendezvousResultType<R> Rendezvous(
absl::string_view name, const K& key, const V& value, size_t num_threads,
Fn fn, absl::Duration warn_stuck_timeout = absl::InfiniteDuration(),
absl::Duration terminate_timeout = absl::InfiniteDuration());

// A rendezvous for a group of threads that do not have any value arguments.
template <typename R, typename K, typename Fn>
RendezvousResultType<R> RendezvousSingle(
RendezvousResultType<R> Rendezvous(
absl::string_view name, const K& key, size_t num_threads, Fn fn,
absl::Duration warn_stuck_timeout = absl::InfiniteDuration(),
absl::Duration terminate_timeout = absl::InfiniteDuration());

// A rendezvous for a group of threads that do not have any computation to run
// and simply acts as a barrier for a group of threads.
template <typename K>
void RendezvousSingle(
absl::string_view name, const K& key, size_t num_threads,
absl::Duration warn_stuck_timeout = absl::InfiniteDuration(),
absl::Duration terminate_timeout = absl::InfiniteDuration());
void Rendezvous(absl::string_view name, const K& key, size_t num_threads,
absl::Duration warn_stuck_timeout = absl::InfiniteDuration(),
absl::Duration terminate_timeout = absl::InfiniteDuration());

// An `std::once_flag`-like primitive for executing RendezvousSingle operations.
// An `std::once_flag`-like primitive for executing Rendezvous operations.
//
// RendezvousSingleFlag guarantees that all or none participants in a rendezvous
// RendezvousFlag guarantees that all or none participants in a rendezvous
// join the rendezvous process, and once the rendezvous is completed the flag is marked as
// `completed` and all further rendezvous using this flag will be skipped. It
// has a weaker than exactly-once guarantee and multiple racing rendezvous can
Expand All @@ -119,17 +118,17 @@ void RendezvousSingle(
// and prefer simpler implementation with weaker guarantees.
//
// See: https://en.cppreference.com/w/cpp/thread/once_flag
class RendezvousSingleFlag {
class RendezvousFlag {
public:
RendezvousSingleFlag();
RendezvousFlag();

RendezvousSingleFlag(const RendezvousSingleFlag&) = delete;
RendezvousSingleFlag& operator=(const RendezvousSingleFlag&) = delete;
RendezvousFlag(const RendezvousFlag&) = delete;
RendezvousFlag& operator=(const RendezvousFlag&) = delete;

// RAII wrapper to exit from in-flight rendezvous when destructed.
class InFlightRendezvous {
public:
explicit InFlightRendezvous(RendezvousSingleFlag* flag);
explicit InFlightRendezvous(RendezvousFlag* flag);
~InFlightRendezvous();

InFlightRendezvous(const InFlightRendezvous&) = delete;
Expand All @@ -138,7 +137,7 @@ class RendezvousSingleFlag {
operator bool() const; // NOLINT

private:
RendezvousSingleFlag* flag_;
RendezvousFlag* flag_;
};

// Returns InFlightRendezvous convertible to `true` if the caller should join
Expand All @@ -159,8 +158,8 @@ class RendezvousSingleFlag {
// rendezvous. If the rendezvous will not be executed, it will return an empty
// shared pointer result.
template <typename R, typename K, typename Fn>
RendezvousResultType<R> RendezvousSingle(
RendezvousSingleFlag& flag, absl::string_view name, const K& key,
RendezvousResultType<R> Rendezvous(
RendezvousFlag& flag, absl::string_view name, const K& key,
size_t num_threads, Fn fn,
absl::Duration warn_stuck_timeout = absl::InfiniteDuration(),
absl::Duration terminate_timeout = absl::InfiniteDuration());
Expand All @@ -169,11 +168,10 @@ RendezvousResultType<R> RendezvousSingle(
// not in `completed` state and will switch it to `completed` after finishing a
// rendezvous.
template <typename K>
void RendezvousSingle(
RendezvousSingleFlag& flag, absl::string_view name, const K& key,
size_t num_threads,
absl::Duration warn_stuck_timeout = absl::InfiniteDuration(),
absl::Duration terminate_timeout = absl::InfiniteDuration());
void Rendezvous(RendezvousFlag& flag, absl::string_view name, const K& key,
size_t num_threads,
absl::Duration warn_stuck_timeout = absl::InfiniteDuration(),
absl::Duration terminate_timeout = absl::InfiniteDuration());

//===----------------------------------------------------------------------===//
// Internal implementation details.
Expand Down Expand Up @@ -291,11 +289,10 @@ void AwaitAndLogIfStuck(RendezvousStateSynchronization& state, int32_t id,
//===----------------------------------------------------------------------===//

template <typename R, typename K, typename V, typename Fn>
RendezvousResultType<R> RendezvousSingle(absl::string_view name, const K& key,
const V& value, size_t num_threads,
Fn fn,
absl::Duration warn_stuck_timeout,
absl::Duration terminate_timeout) {
RendezvousResultType<R> Rendezvous(absl::string_view name, const K& key,
const V& value, size_t num_threads, Fn fn,
absl::Duration warn_stuck_timeout,
absl::Duration terminate_timeout) {
// Check that `fn` is callable with a span of values and returns `R`.
static_assert(std::is_invocable_r_v<R, Fn, absl::Span<const V*>>,
"invalid rendezvous function signature");
Expand All @@ -319,7 +316,7 @@ RendezvousResultType<R> RendezvousSingle(absl::string_view name, const K& key,

tsl::profiler::TraceMe trace([&] {
return tsl::profiler::TraceMeEncode(
"RendezvousSingle",
"Rendezvous",
{{"num_threads", num_threads}, {"name", name}, {"id", id}});
});

Expand Down Expand Up @@ -355,46 +352,44 @@ RendezvousResultType<R> RendezvousSingle(absl::string_view name, const K& key,
}

template <typename R, typename K, typename Fn>
RendezvousResultType<R> RendezvousSingle(absl::string_view name, const K& key,
size_t num_threads, Fn fn,
absl::Duration warn_stuck_timeout,
absl::Duration terminate_timeout) {
return RendezvousSingle<R, K, std::nullopt_t>(
RendezvousResultType<R> Rendezvous(absl::string_view name, const K& key,
size_t num_threads, Fn fn,
absl::Duration warn_stuck_timeout,
absl::Duration terminate_timeout) {
return Rendezvous<R, K, std::nullopt_t>(
name, key, std::nullopt, num_threads, [fn](auto) { return fn(); },
warn_stuck_timeout, terminate_timeout);
}

template <typename K>
void RendezvousSingle(absl::string_view name, const K& key, size_t num_threads,
absl::Duration warn_stuck_timeout,
absl::Duration terminate_timeout) {
RendezvousSingle<std::nullopt_t, K, std::nullopt_t>(
void Rendezvous(absl::string_view name, const K& key, size_t num_threads,
absl::Duration warn_stuck_timeout,
absl::Duration terminate_timeout) {
Rendezvous<std::nullopt_t, K, std::nullopt_t>(
name, key, std::nullopt, num_threads, [](auto) { return std::nullopt; },
warn_stuck_timeout, terminate_timeout);
}

template <typename R, typename K, typename Fn>
RendezvousResultType<R> RendezvousSingle(RendezvousSingleFlag& flag,
absl::string_view name, const K& key,
size_t num_threads, Fn fn,
absl::Duration warn_stuck_timeout,
absl::Duration terminate_timeout) {
RendezvousResultType<R> Rendezvous(RendezvousFlag& flag, absl::string_view name,
const K& key, size_t num_threads, Fn fn,
absl::Duration warn_stuck_timeout,
absl::Duration terminate_timeout) {
if (auto in_flight_rendezvous = flag.TryJoin()) {
return RendezvousSingle<K>(name, key, num_threads, std::move(fn),
warn_stuck_timeout, terminate_timeout);
return Rendezvous<K>(name, key, num_threads, std::move(fn),
warn_stuck_timeout, terminate_timeout);
} else {
return RendezvousResult<R>::Empty();
}
}

template <typename K>
void RendezvousSingle(RendezvousSingleFlag& flag, absl::string_view name,
const K& key, size_t num_threads,
absl::Duration warn_stuck_timeout,
absl::Duration terminate_timeout) {
void Rendezvous(RendezvousFlag& flag, absl::string_view name, const K& key,
size_t num_threads, absl::Duration warn_stuck_timeout,
absl::Duration terminate_timeout) {
if (auto in_flight_rendezvous = flag.TryJoin()) {
RendezvousSingle<K>(name, key, num_threads, warn_stuck_timeout,
terminate_timeout);
Rendezvous<K>(name, key, num_threads, warn_stuck_timeout,
terminate_timeout);
}
}

Expand Down
Loading

0 comments on commit 2ff173c

Please sign in to comment.