Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XLA:CPU] Remove unused stream_executor host code. #21199

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions xla/backends/cpu/collectives/in_process_communicator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ absl::Status InProcessCommunicator::AllReduce(se::DeviceMemoryBase send_buffer,
std::string name = absl::StrCat("all reduce ", key.ToString());
AllReduceParticipant partiticipant{rank_, send_buffer, recv_buffer};

return RendezvousSingle<absl::Status>(
return Rendezvous<absl::Status>(
name, key, partiticipant, key.num_local_participants,
std::bind(AllReduceOp, dtype, count, reduction_kind,
std::placeholders::_1));
Expand All @@ -362,7 +362,7 @@ absl::Status InProcessCommunicator::CollectivePermute(
recv_buffer};

size_t num_bytes = count * primitive_util::ByteWidth(dtype);
return RendezvousSingle<absl::Status>(
return Rendezvous<absl::Status>(
name, key, partiticipant, key.num_local_participants,
std::bind(CollectivePermuteOp, num_bytes, std::placeholders::_1));
}
Expand All @@ -380,7 +380,7 @@ absl::Status InProcessCommunicator::AllToAll(
{recv_buffers.begin(), recv_buffers.end()}};

size_t num_bytes = count * primitive_util::ByteWidth(dtype);
return RendezvousSingle<absl::Status>(
return Rendezvous<absl::Status>(
name, key, partiticipant, key.num_local_participants,
std::bind(AllToAllOp, num_bytes, std::placeholders::_1));
}
Expand All @@ -396,7 +396,7 @@ absl::Status InProcessCommunicator::AllGather(se::DeviceMemoryBase send_buffer,
AllGatherParticipant partiticipant{rank_, send_buffer, recv_buffer};

size_t num_bytes = count * primitive_util::ByteWidth(dtype);
return RendezvousSingle<absl::Status>(
return Rendezvous<absl::Status>(
name, key, partiticipant, key.num_local_participants,
std::bind(AllGatherOp, num_bytes, std::placeholders::_1));
}
Expand All @@ -411,7 +411,7 @@ absl::Status InProcessCommunicator::ReduceScatter(
std::string name = absl::StrCat("reduce scatter ", key.ToString());
ReduceScatterParticipant partiticipant{rank_, send_buffer, recv_buffer};

return RendezvousSingle<absl::Status>(
return Rendezvous<absl::Status>(
name, key, partiticipant, key.num_local_participants,
std::bind(ReduceScatterOp, dtype, count, reduction_kind,
std::placeholders::_1));
Expand Down
6 changes: 3 additions & 3 deletions xla/backends/gpu/collectives/gpu_cliques.cc
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ InitializeGpuClique(GpuCollectives* collectives, se::StreamExecutor* device,
// processes are not able to synchronize device activity.
RendezvousArg rendezvous_arg = std::make_pair(device_rank, synchronized);

return RendezvousSingle<absl::StatusOr<LockableGpuClique::Lock>>(
return Rendezvous<absl::StatusOr<LockableGpuClique::Lock>>(
initialization_rendezvous_name, rendezvous_key, rendezvous_arg,
num_local_participants, initialize, WarnStuckTimeout(),
TerminateTimeout());
Expand Down Expand Up @@ -431,7 +431,7 @@ InitializeGpuClique(GpuCollectives* collectives, se::StreamExecutor* device,
rank.value(), clique_key.ToString(), run_id.ToInt(),
parent_clique_key.ToString());

return RendezvousSingle<absl::StatusOr<LockableGpuClique::Lock>>(
return Rendezvous<absl::StatusOr<LockableGpuClique::Lock>>(
initialization_rendezvous_name, rendezvous_key, rank_pair,
num_local_participants, split, WarnStuckTimeout(), TerminateTimeout());
}
Expand Down Expand Up @@ -466,7 +466,7 @@ absl::StatusOr<std::shared_ptr<LockableGpuClique::Lock>> AcquireGpuClique(

TF_ASSIGN_OR_RETURN(
std::shared_ptr<LockableGpuClique::Lock> clique,
RendezvousSingle<absl::StatusOr<LockableGpuClique::Lock>>(
Rendezvous<absl::StatusOr<LockableGpuClique::Lock>>(
rendezvous_name, rendezvous_key, num_local_participants,
[&] {
tsl::profiler::TraceMe trace("LockGpuClique");
Expand Down
4 changes: 2 additions & 2 deletions xla/debug_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2124,13 +2124,13 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
int32_setter_for(
&DebugOptions::set_xla_gpu_executable_warn_stuck_timeout_seconds),
debug_options->xla_gpu_executable_warn_stuck_timeout_seconds(),
"Set timeout for RendezvousSingle stuck warning"));
"Set timeout for Rendezvous stuck warning"));
flag_list->push_back(tsl::Flag(
"xla_gpu_executable_terminate_timeout",
int32_setter_for(
&DebugOptions::set_xla_gpu_executable_terminate_timeout_seconds),
debug_options->xla_gpu_executable_terminate_timeout_seconds(),
"Set timeout for RendezvousSingle termination"));
"Set timeout for Rendezvous termination"));
flag_list->push_back(tsl::Flag(
"xla_gpu_experimental_disable_binary_libraries",
bool_setter_for(
Expand Down
2 changes: 1 addition & 1 deletion xla/service/gpu/gpu_executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,7 @@ absl::Status RendezvousAfterInitialization(
run_options->device_ordinal(),
run_options->run_options().run_id().ToInt());

RendezvousSingle(
Rendezvous(
rendezvous_name, rendezvous_key, num_local_participants,
absl::Seconds(
debug_options
Expand Down
8 changes: 4 additions & 4 deletions xla/service/gpu/runtime/nccl_collective_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -479,10 +479,10 @@ absl::Status NcclCollectiveThunk::ExecuteOnStream(const ExecuteParams& params) {
"first call to collective operation %d; run_id=%d", config().op_id,
params.collective_params->run_id.ToInt());

RendezvousSingle(first_call_rendezvous_flag_, rendezvous_name,
rendezvous_key, num_local_participants,
/*warn_stuck_timeout=*/absl::Seconds(20),
/*terminate_timeout=*/absl::Seconds(40));
Rendezvous(first_call_rendezvous_flag_, rendezvous_name, rendezvous_key,
num_local_participants,
/*warn_stuck_timeout=*/absl::Seconds(20),
/*terminate_timeout=*/absl::Seconds(40));
}

return absl::OkStatus();
Expand Down
2 changes: 1 addition & 1 deletion xla/service/gpu/runtime/nccl_collective_thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ class NcclCollectiveThunk : public Thunk {
//
// TODO(ezhulenev): Try to move this flag to NCCL clique as we need to make
// sure that all NCCL resources are allocated just once.
RendezvousSingleFlag first_call_rendezvous_flag_;
RendezvousFlag first_call_rendezvous_flag_;
};

//===----------------------------------------------------------------------===//
Expand Down
15 changes: 6 additions & 9 deletions xla/service/rendezvous.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,13 +137,12 @@ inline constexpr int32_t kPending = 0;
inline constexpr int32_t kCompleted = std::numeric_limits<int32_t>::max();
} // namespace

RendezvousSingleFlag::RendezvousSingleFlag() : state_(kPending) {}
RendezvousFlag::RendezvousFlag() : state_(kPending) {}

RendezvousSingleFlag::InFlightRendezvous::InFlightRendezvous(
RendezvousSingleFlag* flag)
RendezvousFlag::InFlightRendezvous::InFlightRendezvous(RendezvousFlag* flag)
: flag_(flag) {}

RendezvousSingleFlag::InFlightRendezvous::~InFlightRendezvous() {
RendezvousFlag::InFlightRendezvous::~InFlightRendezvous() {
if (flag_ == nullptr) return;

// Reload state and use CAS to decide if we are the one who
Expand All @@ -162,11 +161,11 @@ RendezvousSingleFlag::InFlightRendezvous::~InFlightRendezvous() {
}
}

RendezvousSingleFlag::InFlightRendezvous::operator bool() const {
RendezvousFlag::InFlightRendezvous::operator bool() const {
return flag_ != nullptr;
}

RendezvousSingleFlag::InFlightRendezvous RendezvousSingleFlag::TryJoin() {
RendezvousFlag::InFlightRendezvous RendezvousFlag::TryJoin() {
// If `state_` is `kCompleted` it means that we have at least one completed
// rendezvous for this flag and can skip it.
if (state_.load() == kCompleted) return InFlightRendezvous(nullptr);
Expand All @@ -184,8 +183,6 @@ RendezvousSingleFlag::InFlightRendezvous RendezvousSingleFlag::TryJoin() {
return InFlightRendezvous(this);
}

bool RendezvousSingleFlag::IsCompleted() const {
return state_.load() == kCompleted;
}
bool RendezvousFlag::IsCompleted() const { return state_.load() == kCompleted; }

} // namespace xla
93 changes: 44 additions & 49 deletions xla/service/rendezvous.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,29 +85,28 @@ using RendezvousResultType = typename RendezvousResult<R>::Type;
// all threads receive the result. Rendezvous must have a human-readable name
// to make it easy to debug stuck and timed-out attempts.
template <typename R, typename K, typename V, typename Fn>
RendezvousResultType<R> RendezvousSingle(
RendezvousResultType<R> Rendezvous(
absl::string_view name, const K& key, const V& value, size_t num_threads,
Fn fn, absl::Duration warn_stuck_timeout = absl::InfiniteDuration(),
absl::Duration terminate_timeout = absl::InfiniteDuration());

// A rendezvous for a group of threads that do not have any value arguments.
template <typename R, typename K, typename Fn>
RendezvousResultType<R> RendezvousSingle(
RendezvousResultType<R> Rendezvous(
absl::string_view name, const K& key, size_t num_threads, Fn fn,
absl::Duration warn_stuck_timeout = absl::InfiniteDuration(),
absl::Duration terminate_timeout = absl::InfiniteDuration());

// A rendezvous for a group of threads that do not have any computation to run
// and simply acts as a barrier for a group of threads.
template <typename K>
void RendezvousSingle(
absl::string_view name, const K& key, size_t num_threads,
absl::Duration warn_stuck_timeout = absl::InfiniteDuration(),
absl::Duration terminate_timeout = absl::InfiniteDuration());
void Rendezvous(absl::string_view name, const K& key, size_t num_threads,
absl::Duration warn_stuck_timeout = absl::InfiniteDuration(),
absl::Duration terminate_timeout = absl::InfiniteDuration());

// An `std::once_flag`-like primitive for executing RendezvousSingle operations.
// An `std::once_flag`-like primitive for executing Rendezvous operations.
//
// RendezvousSingleFlag guarantees that all or none participants in a rendezvous
// RendezvousFlag guarantees that either all or none of the participants in a
// rendezvous join the rendezvous process. Once a rendezvous is completed, the
// flag is marked as `completed` and all further rendezvous using this flag
// will be skipped. It
// has a weaker than exactly-once guarantee and multiple racing rendezvous can
Expand All @@ -119,17 +118,17 @@ void RendezvousSingle(
// and prefer simpler implementation with weaker guarantees.
//
// See: https://en.cppreference.com/w/cpp/thread/once_flag
class RendezvousSingleFlag {
class RendezvousFlag {
public:
RendezvousSingleFlag();
RendezvousFlag();

RendezvousSingleFlag(const RendezvousSingleFlag&) = delete;
RendezvousSingleFlag& operator=(const RendezvousSingleFlag&) = delete;
RendezvousFlag(const RendezvousFlag&) = delete;
RendezvousFlag& operator=(const RendezvousFlag&) = delete;

// RAII wrapper to exit from in-flight rendezvous when destructed.
class InFlightRendezvous {
public:
explicit InFlightRendezvous(RendezvousSingleFlag* flag);
explicit InFlightRendezvous(RendezvousFlag* flag);
~InFlightRendezvous();

InFlightRendezvous(const InFlightRendezvous&) = delete;
Expand All @@ -138,7 +137,7 @@ class RendezvousSingleFlag {
operator bool() const; // NOLINT

private:
RendezvousSingleFlag* flag_;
RendezvousFlag* flag_;
};

// Returns InFlightRendezvous convertible to `true` if the caller should join
Expand All @@ -159,8 +158,8 @@ class RendezvousSingleFlag {
// rendezvous. If the rendezvous is not executed, it returns an empty shared
// pointer result.
template <typename R, typename K, typename Fn>
RendezvousResultType<R> RendezvousSingle(
RendezvousSingleFlag& flag, absl::string_view name, const K& key,
RendezvousResultType<R> Rendezvous(
RendezvousFlag& flag, absl::string_view name, const K& key,
size_t num_threads, Fn fn,
absl::Duration warn_stuck_timeout = absl::InfiniteDuration(),
absl::Duration terminate_timeout = absl::InfiniteDuration());
Expand All @@ -169,11 +168,10 @@ RendezvousResultType<R> RendezvousSingle(
// not in `completed` state and will switch it to `completed` after finishing a
// rendezvous.
template <typename K>
void RendezvousSingle(
RendezvousSingleFlag& flag, absl::string_view name, const K& key,
size_t num_threads,
absl::Duration warn_stuck_timeout = absl::InfiniteDuration(),
absl::Duration terminate_timeout = absl::InfiniteDuration());
void Rendezvous(RendezvousFlag& flag, absl::string_view name, const K& key,
size_t num_threads,
absl::Duration warn_stuck_timeout = absl::InfiniteDuration(),
absl::Duration terminate_timeout = absl::InfiniteDuration());

//===----------------------------------------------------------------------===//
// Internal implementation details.
Expand Down Expand Up @@ -291,11 +289,10 @@ void AwaitAndLogIfStuck(RendezvousStateSynchronization& state, int32_t id,
//===----------------------------------------------------------------------===//

template <typename R, typename K, typename V, typename Fn>
RendezvousResultType<R> RendezvousSingle(absl::string_view name, const K& key,
const V& value, size_t num_threads,
Fn fn,
absl::Duration warn_stuck_timeout,
absl::Duration terminate_timeout) {
RendezvousResultType<R> Rendezvous(absl::string_view name, const K& key,
const V& value, size_t num_threads, Fn fn,
absl::Duration warn_stuck_timeout,
absl::Duration terminate_timeout) {
// Check that `fn` is callable with a span of values and returns `R`.
static_assert(std::is_invocable_r_v<R, Fn, absl::Span<const V*>>,
"invalid rendezvous function signature");
Expand All @@ -319,7 +316,7 @@ RendezvousResultType<R> RendezvousSingle(absl::string_view name, const K& key,

tsl::profiler::TraceMe trace([&] {
return tsl::profiler::TraceMeEncode(
"RendezvousSingle",
"Rendezvous",
{{"num_threads", num_threads}, {"name", name}, {"id", id}});
});

Expand Down Expand Up @@ -355,46 +352,44 @@ RendezvousResultType<R> RendezvousSingle(absl::string_view name, const K& key,
}

template <typename R, typename K, typename Fn>
RendezvousResultType<R> RendezvousSingle(absl::string_view name, const K& key,
size_t num_threads, Fn fn,
absl::Duration warn_stuck_timeout,
absl::Duration terminate_timeout) {
return RendezvousSingle<R, K, std::nullopt_t>(
RendezvousResultType<R> Rendezvous(absl::string_view name, const K& key,
size_t num_threads, Fn fn,
absl::Duration warn_stuck_timeout,
absl::Duration terminate_timeout) {
return Rendezvous<R, K, std::nullopt_t>(
name, key, std::nullopt, num_threads, [fn](auto) { return fn(); },
warn_stuck_timeout, terminate_timeout);
}

template <typename K>
void RendezvousSingle(absl::string_view name, const K& key, size_t num_threads,
absl::Duration warn_stuck_timeout,
absl::Duration terminate_timeout) {
RendezvousSingle<std::nullopt_t, K, std::nullopt_t>(
void Rendezvous(absl::string_view name, const K& key, size_t num_threads,
absl::Duration warn_stuck_timeout,
absl::Duration terminate_timeout) {
Rendezvous<std::nullopt_t, K, std::nullopt_t>(
name, key, std::nullopt, num_threads, [](auto) { return std::nullopt; },
warn_stuck_timeout, terminate_timeout);
}

template <typename R, typename K, typename Fn>
RendezvousResultType<R> RendezvousSingle(RendezvousSingleFlag& flag,
absl::string_view name, const K& key,
size_t num_threads, Fn fn,
absl::Duration warn_stuck_timeout,
absl::Duration terminate_timeout) {
RendezvousResultType<R> Rendezvous(RendezvousFlag& flag, absl::string_view name,
const K& key, size_t num_threads, Fn fn,
absl::Duration warn_stuck_timeout,
absl::Duration terminate_timeout) {
if (auto in_flight_rendezvous = flag.TryJoin()) {
return RendezvousSingle<K>(name, key, num_threads, std::move(fn),
warn_stuck_timeout, terminate_timeout);
return Rendezvous<K>(name, key, num_threads, std::move(fn),
warn_stuck_timeout, terminate_timeout);
} else {
return RendezvousResult<R>::Empty();
}
}

template <typename K>
void RendezvousSingle(RendezvousSingleFlag& flag, absl::string_view name,
const K& key, size_t num_threads,
absl::Duration warn_stuck_timeout,
absl::Duration terminate_timeout) {
void Rendezvous(RendezvousFlag& flag, absl::string_view name, const K& key,
size_t num_threads, absl::Duration warn_stuck_timeout,
absl::Duration terminate_timeout) {
if (auto in_flight_rendezvous = flag.TryJoin()) {
RendezvousSingle<K>(name, key, num_threads, warn_stuck_timeout,
terminate_timeout);
Rendezvous<K>(name, key, num_threads, warn_stuck_timeout,
terminate_timeout);
}
}

Expand Down
Loading