From 40286abb057803dade1a219dc2f796c60fdd1e0f Mon Sep 17 00:00:00 2001
From: Nicolas Castet
Date: Fri, 10 Jan 2025 10:11:20 -0600
Subject: [PATCH] [XLA:GPU] Add support for NCCL ncclCommInitRankScalable API

---
 .../cpu/collectives/gloo_collectives.cc       |  2 +-
 .../cpu/collectives/gloo_collectives.h        |  2 +-
 .../cpu/collectives/in_process_collectives.cc |  2 +-
 .../cpu/collectives/in_process_collectives.h  |  2 +-
 .../cpu/collectives/mpi_collectives.cc        |  2 +-
 .../cpu/collectives/mpi_collectives.h         |  2 +-
 xla/backends/gpu/collectives/gpu_clique.cc    | 10 +++---
 xla/backends/gpu/collectives/gpu_clique.h     |  8 ++---
 .../gpu/collectives/gpu_clique_key.cc         | 36 ++++++++++++++++---
 xla/backends/gpu/collectives/gpu_clique_key.h |  9 +++++
 .../gpu/collectives/gpu_clique_locking.cc     | 30 ++++++++++++----
 .../gpu/collectives/gpu_collectives_stub.h    |  2 +-
 .../gpu/collectives/nccl_collectives.cc       | 29 +++++++++++----
 .../gpu/collectives/nccl_collectives.h        |  2 +-
 xla/core/collectives/clique_id.cc             | 16 +++++++++
 xla/core/collectives/clique_id.h              | 25 +++++++++++++
 xla/core/collectives/collectives.h            |  2 +-
 xla/debug_options_flags.cc                    |  9 +++++
 xla/pjrt/gpu/nccl_id_store.cc                 |  2 +-
 xla/tsl/cuda/nccl.symbols                     |  2 ++
 xla/xla.proto                                 |  3 ++
 21 files changed, 159 insertions(+), 38 deletions(-)

diff --git a/xla/backends/cpu/collectives/gloo_collectives.cc b/xla/backends/cpu/collectives/gloo_collectives.cc
index 5880704f3c680c..31fa9314783550 100644
--- a/xla/backends/cpu/collectives/gloo_collectives.cc
+++ b/xla/backends/cpu/collectives/gloo_collectives.cc
@@ -54,7 +54,7 @@ GlooCollectives::~GlooCollectives() = default;
 absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
 GlooCollectives::CreateCommunicators(int32_t nranks,
                                      const CliqueKey& clique_key,
-                                     const std::optional<CliqueId>& clique_id,
+                                     const std::optional<CliqueIds>& clique_ids,
                                      absl::Span<const DeviceRank> ranks,
                                      const Config& config) {
   std::vector<std::unique_ptr<Communicator>> communicators;
diff --git a/xla/backends/cpu/collectives/gloo_collectives.h b/xla/backends/cpu/collectives/gloo_collectives.h
index 740e8ddc8bc215..3a93b2ab6f6df2 100644
--- a/xla/backends/cpu/collectives/gloo_collectives.h
+++ b/xla/backends/cpu/collectives/gloo_collectives.h
@@ -42,7 +42,7 @@ class GlooCollectives : public CpuCollectives {
 
   absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
   CreateCommunicators(int32_t nranks, const CliqueKey& clique_key,
-                      const std::optional<CliqueId>& clique_id,
+                      const std::optional<CliqueIds>& clique_ids,
                       absl::Span<const DeviceRank> ranks,
                       const Config& config) final;
diff --git a/xla/backends/cpu/collectives/in_process_collectives.cc b/xla/backends/cpu/collectives/in_process_collectives.cc
index 80227ad7550cc2..d427f524111c7c 100644
--- a/xla/backends/cpu/collectives/in_process_collectives.cc
+++ b/xla/backends/cpu/collectives/in_process_collectives.cc
@@ -36,7 +36,7 @@ namespace xla::cpu {
 absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
 InProcessCollectives::CreateCommunicators(
     int32_t nranks, const CliqueKey& clique_key,
-    const std::optional<CliqueId>& clique_id,
+    const std::optional<CliqueIds>& clique_ids,
     absl::Span<const DeviceRank> ranks, const Config& config) {
   absl::MutexLock lock(&mu_);
diff --git a/xla/backends/cpu/collectives/in_process_collectives.h b/xla/backends/cpu/collectives/in_process_collectives.h
index 2fd5e53afcf320..2cdc1160432250 100644
--- a/xla/backends/cpu/collectives/in_process_collectives.h
+++ b/xla/backends/cpu/collectives/in_process_collectives.h
@@ -38,7 +38,7 @@ class InProcessCollectives : public CpuCollectives {
  public:
   absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
   CreateCommunicators(int32_t nranks, const CliqueKey& clique_key,
-                      const std::optional<CliqueId>& clique_id,
+                      const std::optional<CliqueIds>& clique_ids,
                       absl::Span<const DeviceRank> ranks,
                       const Config& config) final;
diff --git a/xla/backends/cpu/collectives/mpi_collectives.cc b/xla/backends/cpu/collectives/mpi_collectives.cc
index 38b2dd1262b8d1..b6b679e802de7c 100644
--- a/xla/backends/cpu/collectives/mpi_collectives.cc
+++ b/xla/backends/cpu/collectives/mpi_collectives.cc
@@ -46,7 +46,7 @@ void MpiCollectives::Finalize() { MPI_Finalize(); }
 absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
 MpiCollectives::CreateCommunicators(int32_t nranks,
                                     const CliqueKey& clique_key,
-                                    const std::optional<CliqueId>& clique_id,
+                                    const std::optional<CliqueIds>& clique_ids,
                                     absl::Span<const DeviceRank> ranks,
                                     const Config& config) {
   int flag;
diff --git a/xla/backends/cpu/collectives/mpi_collectives.h b/xla/backends/cpu/collectives/mpi_collectives.h
index 82722b954121af..49e6fdfa9c63a5 100644
--- a/xla/backends/cpu/collectives/mpi_collectives.h
+++ b/xla/backends/cpu/collectives/mpi_collectives.h
@@ -49,7 +49,7 @@ class MpiCollectives : public CpuCollectives {
 
   absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
   CreateCommunicators(int32_t nranks, const CliqueKey& clique_key,
-                      const std::optional<CliqueId>& clique_id,
+                      const std::optional<CliqueIds>& clique_ids,
                       absl::Span<const DeviceRank> ranks,
                       const Config& config) final;
diff --git a/xla/backends/gpu/collectives/gpu_clique.cc b/xla/backends/gpu/collectives/gpu_clique.cc
index affc92419f5cc3..b0b751235a5233 100644
--- a/xla/backends/gpu/collectives/gpu_clique.cc
+++ b/xla/backends/gpu/collectives/gpu_clique.cc
@@ -37,14 +37,14 @@ limitations under the License.
 namespace xla::gpu {
 
 GpuClique::GpuClique(
-    GpuCliqueKey key, std::optional<CliqueId> id,
+    GpuCliqueKey key, std::optional<CliqueIds> ids,
     absl::btree_map<RankId, std::unique_ptr<Communicator>> communicators)
-    : Clique(std::move(communicators)), key_(key), id_(id) {}
+    : Clique(std::move(communicators)), key_(key), ids_(ids) {}
 
 std::string GpuClique::DebugString() const {
   std::string out = absl::StrFormat(
       "key: %s; fingerprint(id): %d; size: %d; communicators: ",
-      key_.ToString(), id_.has_value() ? id_->fingerprint() : 0,
+      key_.ToString(), ids_.has_value() ? ids_->fingerprint() : 0,
       num_communicators());
   int32_t cnt = 0;
   ForEachComm([&](RankId rank, Communicator* comm) {
@@ -70,9 +70,9 @@ std::string GpuClique::LockableName::ToString(const GpuClique& clique) {
 }
 
 LockableGpuClique::LockableGpuClique(
-    GpuCliqueKey clique_key, std::optional<CliqueId> clique_id,
+    GpuCliqueKey clique_key, std::optional<CliqueIds> clique_ids,
     absl::btree_map<RankId, std::unique_ptr<Communicator>> communicators)
-    : Lockable(std::move(clique_key), clique_id, std::move(communicators)) {}
+    : Lockable(std::move(clique_key), clique_ids, std::move(communicators)) {}
 
 absl::Status LockableGpuClique::HealthCheck() const {
   return value().HealthCheck();
diff --git a/xla/backends/gpu/collectives/gpu_clique.h b/xla/backends/gpu/collectives/gpu_clique.h
index 3a2a3500c54082..c5c48712d9edbc 100644
--- a/xla/backends/gpu/collectives/gpu_clique.h
+++ b/xla/backends/gpu/collectives/gpu_clique.h
@@ -40,7 +40,7 @@ class LockableGpuClique;
 class GpuClique : public Clique {
  public:
   GpuClique(
-      GpuCliqueKey key, std::optional<CliqueId> id,
+      GpuCliqueKey key, std::optional<CliqueIds> ids,
       absl::btree_map<RankId, std::unique_ptr<Communicator>> communicators);
 
   // Returns true if clique is local: all communicators belong to current
@@ -48,7 +48,7 @@ class GpuClique : public Clique {
   bool IsLocal() const { return num_communicators() == key_.devices().size(); }
 
   const GpuCliqueKey& key() const { return key_; }
-  const std::optional<CliqueId>& id() const { return id_; }
+  const std::optional<CliqueIds>& ids() const { return ids_; }
 
   std::string DebugString() const final;
   absl::Status HealthCheck() const final;
@@ -62,7 +62,7 @@ class GpuClique : public Clique {
   };
 
   GpuCliqueKey key_;
-  std::optional<CliqueId> id_;
+  std::optional<CliqueIds> ids_;
 };
 
 // A lockable version of GpuClique that guarantees exclusive access to the
@@ -70,7 +70,7 @@ class GpuClique : public Clique {
 class LockableGpuClique : public Lockable<GpuClique, GpuClique::LockableName> {
  public:
   LockableGpuClique(
-      GpuCliqueKey clique_key, std::optional<CliqueId> clique_id,
+      GpuCliqueKey clique_key, std::optional<CliqueIds> clique_ids,
       absl::btree_map<RankId, std::unique_ptr<Communicator>> communicators);
 
   std::string DebugString() const;
diff --git a/xla/backends/gpu/collectives/gpu_clique_key.cc b/xla/backends/gpu/collectives/gpu_clique_key.cc
index 378ae084038b0d..decd4d46b71d4a 100644
--- a/xla/backends/gpu/collectives/gpu_clique_key.cc
+++ b/xla/backends/gpu/collectives/gpu_clique_key.cc
@@ -49,7 +49,8 @@ GpuCliqueKey::GpuCliqueKey(
     : CliqueKey(std::move(devices)),
       stream_id_(stream_id),
       stream_kind_(stream_kind),
-      participant_groups_(std::move(participant_groups)) {
+      participant_groups_(std::move(participant_groups)),
+      root_device_(-1) {
   for (std::vector<GlobalDeviceId>& group : participant_groups_) {
     absl::c_sort(group);
   }
@@ -65,6 +66,8 @@ GpuCliqueKey::GpuCliqueKey(
 
 CollectiveStreamId GpuCliqueKey::stream_id() const { return stream_id_; }
 
+GlobalDeviceId GpuCliqueKey::root_device() const { return root_device_; }
+
 bool GpuCliqueKey::IsSubsetOf(const CliqueKey& other) const {
   auto* other_nccl = tsl::down_cast<const GpuCliqueKey*>(&other);
   if (other_nccl == nullptr) return false;
@@ -75,6 +78,22 @@ bool GpuCliqueKey::IsSubsetOf(const CliqueKey& other) const {
   });
 }
 
+GpuCliqueKey GpuCliqueKey::GetSubKey(int64_t nroots, int64_t root_seq_id) const {
+  CHECK(root_seq_id < nroots);
+  GpuCliqueKey subkey(*this);
+  const auto& devs = devices();
+  int64_t nranks = devs.size();
+  CHECK(nroots <= nranks);
+  int64_t rank_per_root = nranks / nroots;
+  int64_t rank_remainder = nranks % nroots;
+  if (root_seq_id < rank_remainder) {
+    subkey.root_device_ = devs[root_seq_id * (rank_per_root + 1)];
+  } else {
+    subkey.root_device_ = devs[rank_remainder * (rank_per_root + 1) +
+                               (root_seq_id - rank_remainder) * rank_per_root];
+  }
+  return subkey;
+}
+
 std::string GpuCliqueKey::ToString() const {
   std::string group_string = "";
   if (!participant_groups_.empty()) {
@@ -85,19 +104,20 @@ std::string GpuCliqueKey::ToString() const {
     }
     group_string = absl::StrFormat("; groups=[%s]", absl::StrJoin(values, ","));
   }
-  return absl::StrFormat("devices=[%s]; stream=%d%s",
+  return absl::StrFormat("devices=[%s]; stream=%d%s; root_device=%lld",
                          GlobalDeviceIdsToString(devices()), stream_id_.value(),
-                         group_string);
+                         group_string, root_device_.value());
 }
 
 void GpuCliqueKey::HashValue(absl::HashState state) const {
   absl::HashState::combine(std::move(state), devices(), stream_id_,
-                           participant_groups_);
+                           participant_groups_, root_device_);
 }
 
 bool operator==(const GpuCliqueKey& a, const GpuCliqueKey& b) {
   return a.devices() == b.devices() && a.stream_id_ == b.stream_id_ &&
-         a.participant_groups_ == b.participant_groups_;
+         a.participant_groups_ == b.participant_groups_ &&
+         a.root_device_ == b.root_device_;
 }
 
 bool operator<(const GpuCliqueKey& a, const GpuCliqueKey& b) {
@@ -107,6 +127,9 @@ bool operator<(const GpuCliqueKey& a, const GpuCliqueKey& b) {
   if (a.devices() < b.devices()) return true;
   if (b.devices() < a.devices()) return false;
 
+  if (a.root_device_ < b.root_device_) return true;
+  if (b.root_device_ < a.root_device_) return false;
+
   return a.stream_id_.value() < b.stream_id_.value();
 }
 
@@ -117,6 +140,9 @@ bool operator>(const GpuCliqueKey& a, const GpuCliqueKey& b) {
   if (a.devices() > b.devices()) return true;
   if (b.devices() > a.devices()) return false;
 
+  if (a.root_device_ > b.root_device_) return true;
+  if (b.root_device_ > a.root_device_) return false;
+
   // We still use `<` to order by stream id as we want to acquire sync cliques
   // before async ones.
   return a.stream_id_.value() < b.stream_id_.value();
diff --git a/xla/backends/gpu/collectives/gpu_clique_key.h b/xla/backends/gpu/collectives/gpu_clique_key.h
index d563db28b0b00c..1f19a177547f64 100644
--- a/xla/backends/gpu/collectives/gpu_clique_key.h
+++ b/xla/backends/gpu/collectives/gpu_clique_key.h
@@ -68,10 +68,17 @@ class GpuCliqueKey : public CliqueKey {
 
   CollectiveStreamId stream_id() const;
 
+  // Device generating the unique id for this key
+  GlobalDeviceId root_device() const;
+
   // Returns true if this clique is a subset of `other`: both cliques have the
   // same `stream_id` and all clique devices are part of `other` clique.
   bool IsSubsetOf(const CliqueKey& other) const final;
 
+  // Returns a copy of the key (subkey) with the root device properly set given
+  // nroots and root_seq_id. The subkey is used to generate a NcclCliqueId.
+  GpuCliqueKey GetSubKey(int64_t nroots, int64_t root_seq_id) const;
+
   // Returns the stream kind for this clique key, stream kind will be used to
   // specify what configuration to pass for each type of operation.
   AsyncStreamKind stream_kind() const { return stream_kind_; }
@@ -105,6 +112,8 @@ class GpuCliqueKey : public CliqueKey {
   // Having the participating groups as part of the cache key will prevent such
   // situations
   std::vector<std::vector<GlobalDeviceId>> participant_groups_;
+
+  GlobalDeviceId root_device_;
 };
 
 bool operator==(const GpuCliqueKey& a, const GpuCliqueKey& b);
diff --git a/xla/backends/gpu/collectives/gpu_clique_locking.cc b/xla/backends/gpu/collectives/gpu_clique_locking.cc
index afee5ad405bbc2..fe24361eaa0513 100644
--- a/xla/backends/gpu/collectives/gpu_clique_locking.cc
+++ b/xla/backends/gpu/collectives/gpu_clique_locking.cc
@@ -206,13 +206,29 @@ InitializeGpuClique(GpuCollectives* collectives, se::StreamExecutor* device,
 
   using RendezvousArg = std::pair<DeviceRank, const GpuCollectives::Config*>;
 
+  // Check how many roots are needed to initialize the GpuClique
+  static const int64_t nccl_init_rank_per_root_ratio =
+      xla::GetDebugOptionsFromFlags().xla_gpu_nccl_init_max_rank_per_root_ratio();
+  int64_t nroots = 1;
+  // Ceiling division to get number of roots
+  if (nccl_init_rank_per_root_ratio > 0) {
+    nroots = (nranks + nccl_init_rank_per_root_ratio - 1) / nccl_init_rank_per_root_ratio;
+  }
+
   // Initializes a GpuClique for given device ranks and returns a lock that
   // gives access to clique communicators.
   auto initialize = [&](absl::Span<const RendezvousArg* const> args)
       -> absl::StatusOr<LockableGpuClique::Lock> {
     tsl::profiler::TraceMe trace("InitializeGpuClique");
-    TF_ASSIGN_OR_RETURN(auto clique_id, clique_id_callback(clique_key));
+    CliqueIds clique_ids;
+    for (int64_t i = 0; i < nroots; ++i) {
+      GpuCliqueKey subkey = clique_key.GetSubKey(nroots, i);
+      VLOG(3) << absl::StreamFormat(
+          "Get CliqueId for sub clique key %s; nroots=%lld",
+          subkey.ToString(), nroots);
+      TF_ASSIGN_OR_RETURN(auto clique_id, clique_id_callback(subkey));
+      clique_ids.Add(clique_id);
+    }
 
     // Check that all ranks successfully synchronized device activity before
     // trying to instantiate GPU communicators.
@@ -234,13 +250,13 @@ InitializeGpuClique(GpuCollectives* collectives, se::StreamExecutor* device,
 
     VLOG(3) << absl::StreamFormat(
         "Create GPU communicators for clique %s; ranks=[%s]; "
-        "fingerprint(id)=%d",
+        "nroots=%lld; fingerprint(id)=%d",
         clique_key.ToString(), DeviceRanksToString(ranks),
-        clique_id.fingerprint());
+        nroots, clique_ids.fingerprint());
 
     TF_ASSIGN_OR_RETURN(
         std::vector<std::unique_ptr<Communicator>> created_comms,
-        collectives->CreateCommunicators(nranks, clique_key, clique_id, ranks,
+        collectives->CreateCommunicators(nranks, clique_key, clique_ids, ranks,
                                          config));
 
     absl::btree_map<RankId, std::unique_ptr<Communicator>> comms;
@@ -250,15 +266,15 @@ InitializeGpuClique(GpuCollectives* collectives, se::StreamExecutor* device,
 
     VLOG(3) << absl::StreamFormat(
         "Created GPU communicators for clique %s; ranks=[%s]; "
-        "fingerprint(id)=%d",
+        "nroots=%lld; fingerprint(id)=%d",
        clique_key.ToString(), DeviceRanksToString(ranks),
-        clique_id.fingerprint());
+        nroots, clique_ids.fingerprint());
 
     ProcessGpuCliques& cliques = GetProcessGpuCliques();
     absl::MutexLock lock(&cliques.mu);
 
     // Create a new clique with given clique key and communicators.
-    auto emplaced = cliques.map.try_emplace(clique_key, clique_key, clique_id,
+    auto emplaced = cliques.map.try_emplace(clique_key, clique_key, clique_ids,
                                             std::move(comms));
 
     // We can have a race to create a clique for a given key, the winner
diff --git a/xla/backends/gpu/collectives/gpu_collectives_stub.h b/xla/backends/gpu/collectives/gpu_collectives_stub.h
index ad64b910c6c97e..b4c221287f0a66 100644
--- a/xla/backends/gpu/collectives/gpu_collectives_stub.h
+++ b/xla/backends/gpu/collectives/gpu_collectives_stub.h
@@ -50,7 +50,7 @@ class GpuCollectivesStub : public GpuCollectives {
   }
 
   absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
-  CreateCommunicators(int32_t, const CliqueKey&, const std::optional<CliqueId>&,
+  CreateCommunicators(int32_t, const CliqueKey&, const std::optional<CliqueIds>&,
                       absl::Span<const DeviceRank>,
                       const Collectives::Config&) final {
     return UnimplementedError();
diff --git a/xla/backends/gpu/collectives/nccl_collectives.cc b/xla/backends/gpu/collectives/nccl_collectives.cc
index faa8caf48a6ec9..8fdd35c415eb8e 100644
--- a/xla/backends/gpu/collectives/nccl_collectives.cc
+++ b/xla/backends/gpu/collectives/nccl_collectives.cc
@@ -113,19 +113,33 @@ static absl::StatusOr<ncclUniqueId> AsNcclUniqueId(const CliqueId& clique_id) {
   return id;
 }
 
+static absl::StatusOr<std::vector<ncclUniqueId>> AsNcclUniqueIds(const CliqueIds& clique_ids) {
+  std::vector<ncclUniqueId> ids;
+  auto ids_vect = clique_ids.data();
+  ids.reserve(ids_vect.size());
+  for (const auto& clique_id : ids_vect) {
+    auto id = AsNcclUniqueId(clique_id);
+    if (!id.ok()) {
+      return id.status();
+    }
+    ids.push_back(id.value());
+  }
+  return ids;
+}
+
 absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
 NcclCollectives::CreateCommunicators(int32_t nranks,
                                      const CliqueKey& clique_key,
-                                     const std::optional<CliqueId>& clique_id,
+                                     const std::optional<CliqueIds>& clique_ids,
                                      absl::Span<const DeviceRank> ranks,
                                      const Collectives::Config& config) {
   // With NCCL backend we rely on host to exchange unique clique id.
-  if (!clique_id.has_value()) {
+  if (!clique_ids.has_value()) {
    return InvalidArgument("CliqueId is required to create NCCL communicators");
   }
 
   VLOG(1) << "Initialize NCCL communicator for " << ranks.size() << " devices"
-          << "; fingerprint(id)=" << clique_id->fingerprint();
+          << "; fingerprint(id)=" << clique_ids->fingerprint();
 
   TF_ASSIGN_OR_RETURN(auto* gpu_config, TryCast<GpuCollectives::Config>(&config));
   ncclConfig_t comm_config = AsNcclConfig(*gpu_config);
@@ -140,14 +154,15 @@ NcclCollectives::CreateCommunicators(int32_t nranks,
   for (size_t i = 0; i < ranks.size(); ++i) {
     VLOG(1) << "Initialize NCCL communicator for rank #" << ranks[i].rank
             << " of " << nranks
-            << "; fingerprint(id)=" << clique_id->fingerprint();
+            << "; fingerprint(id)=" << clique_ids->fingerprint();
 
     TF_ASSIGN_OR_RETURN(auto* device, TryCast<GpuCollectives::Device>(ranks[i].device));
     auto activate_context = device->stream_executor()->Activate();
-    TF_ASSIGN_OR_RETURN(auto nccl_unique_id, AsNcclUniqueId(*clique_id));
+    TF_ASSIGN_OR_RETURN(auto nccl_unique_ids, AsNcclUniqueIds(*clique_ids));
     XLA_NCCL_RETURN_IF_ERROR(
-        ncclCommInitRankConfig(&comm_handles[i], nranks, nccl_unique_id,
-                               ranks[i].rank.value(), &comm_config));
+        ncclCommInitRankScalable(&comm_handles[i], nranks, ranks[i].rank.value(),
+                                 nccl_unique_ids.size(), nccl_unique_ids.data(),
+                                 &comm_config));
   }
 
   TF_RETURN_IF_ERROR(GroupEnd());
diff --git a/xla/backends/gpu/collectives/nccl_collectives.h b/xla/backends/gpu/collectives/nccl_collectives.h
index c8fb34f6276355..60499d72de3850 100644
--- a/xla/backends/gpu/collectives/nccl_collectives.h
+++ b/xla/backends/gpu/collectives/nccl_collectives.h
@@ -50,7 +50,7 @@ class NcclCollectives : public GpuCollectives {
 
   absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
   CreateCommunicators(int32_t nranks, const CliqueKey& clique_key,
-                      const std::optional<CliqueId>& clique_id,
+                      const std::optional<CliqueIds>& clique_ids,
                       absl::Span<const DeviceRank> ranks,
                       const Collectives::Config& config) final;
diff --git a/xla/core/collectives/clique_id.cc b/xla/core/collectives/clique_id.cc
index f59b7ce5999692..e94b202bfc793f 100644
--- a/xla/core/collectives/clique_id.cc
+++ b/xla/core/collectives/clique_id.cc
@@ -40,4 +40,20 @@ uint32_t CliqueId::fingerprint() const {
 
 size_t CliqueId::size() const { return data_.size(); }
 
+CliqueIds::CliqueIds(const CliqueId& id) { Add(id); }
+
+void CliqueIds::Add(const CliqueId& id) { ids_.push_back(id); }
+
+std::vector<CliqueId> CliqueIds::data() const { return ids_; }
+
+uint32_t CliqueIds::fingerprint() const {
+  absl::crc32c_t crc(0);
+  for (const auto& clique_id : ids_) {
+    crc = absl::ExtendCrc32c(crc, absl::string_view(clique_id.data().data(),
+                                                    clique_id.data().size()));
+  }
+  return static_cast<uint32_t>(crc);
+}
+
+
 }  // namespace xla
diff --git a/xla/core/collectives/clique_id.h b/xla/core/collectives/clique_id.h
index 104e1dbde2d9c8..b129f73d974cf4 100644
--- a/xla/core/collectives/clique_id.h
+++ b/xla/core/collectives/clique_id.h
@@ -59,6 +59,31 @@ H AbslHashValue(H h, const CliqueId& id) {
   return H::combine(std::move(h), id.data());
 }
 
+// Collection of CliqueIds
+class CliqueIds {
+ public:
+  CliqueIds() = default;
+
+  CliqueIds(const CliqueId& id);
+
+  void Add(const CliqueId& id);
+
+  std::vector<CliqueId> data() const;
+
+  uint32_t fingerprint() const;
+
+  template <typename H>
+  friend H AbslHashValue(H h, const CliqueIds& ids);
+
+ private:
+  std::vector<CliqueId> ids_;
+};
+
+template <typename H>
+H AbslHashValue(H h, const CliqueIds& ids) {
+  return H::combine(std::move(h), ids.data());
+}
+
 }  // namespace xla
 
 #endif  // XLA_CORE_COLLECTIVES_CLIQUE_ID_H_
diff --git a/xla/core/collectives/collectives.h b/xla/core/collectives/collectives.h
index 4b41a0dd440816..2f30a422b1b84e 100644
--- a/xla/core/collectives/collectives.h
+++ b/xla/core/collectives/collectives.h
@@ -71,7 +71,7 @@ class Collectives {
   // Creates communicators for given clique key and id.
   virtual absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
   CreateCommunicators(int32_t nranks, const CliqueKey& clique_key,
-                      const std::optional<CliqueId>& clique_id,
+                      const std::optional<CliqueIds>& clique_ids,
                       absl::Span<const DeviceRank> ranks,
                       const Config& config) = 0;
diff --git a/xla/debug_options_flags.cc b/xla/debug_options_flags.cc
index b78b87c8a15dff..1a2628aadb8184 100644
--- a/xla/debug_options_flags.cc
+++ b/xla/debug_options_flags.cc
@@ -161,6 +161,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
   opts.set_xla_gpu_enable_nccl_user_buffers(false);
   opts.set_xla_gpu_enable_nccl_comm_splitting(true);
   opts.set_xla_gpu_enable_nccl_per_stream_comms(false);
+  opts.set_xla_gpu_nccl_init_max_rank_per_root_ratio(128);
   opts.set_xla_gpu_temp_buffer_use_separate_color(false);
   opts.set_xla_gpu_require_exclusive_lock(false);
@@ -1537,6 +1538,14 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
       "NCCL collective is executed on. This can lead to higher performance if "
       "NCCL collectives are issued concurrently at the cost of more GPU memory"
       " usage."));
+  flag_list->push_back(tsl::Flag(
+      "xla_gpu_nccl_init_max_rank_per_root_ratio",
+      int64_setter_for(
+          &DebugOptions::set_xla_gpu_nccl_init_max_rank_per_root_ratio),
+      debug_options->xla_gpu_nccl_init_max_rank_per_root_ratio(),
+      "Maximum number of ranks associated with a root rank to initialize a "
+      "NCCL communicator via ncclCommInitRankScalable. "
+      "A value of zero will lead to a single root."));
   flag_list->push_back(tsl::Flag(
       "xla_gpu_redzone_scratch_max_megabytes",
       int64_setter_for(
diff --git a/xla/pjrt/gpu/nccl_id_store.cc b/xla/pjrt/gpu/nccl_id_store.cc
index be0684da19d6f7..a2a72856e6d9f3 100644
--- a/xla/pjrt/gpu/nccl_id_store.cc
+++ b/xla/pjrt/gpu/nccl_id_store.cc
@@ -49,7 +49,7 @@ absl::StatusOr<CliqueId> NcclIdStore::GetNcclUniqueId(const CliqueKey& key) {
     }
   }
   CliqueId clique_id;
-  int primary_node_id = device_to_node_.at(gpu_key->devices()[0]);
+  int primary_node_id = device_to_node_.at(gpu_key->root_device());
   if (node_id_ == primary_node_id) {
     TF_ASSIGN_OR_RETURN(clique_id,
                         gpu::GpuCollectives::Default()->CreateUniqueCliqueId());
diff --git a/xla/tsl/cuda/nccl.symbols b/xla/tsl/cuda/nccl.symbols
index e5164825373af7..c1647e8ca713d8 100644
--- a/xla/tsl/cuda/nccl.symbols
+++ b/xla/tsl/cuda/nccl.symbols
@@ -11,6 +11,7 @@ ncclCommGetAsyncError
 ncclCommInitAll
 ncclCommInitRank
 ncclCommInitRankConfig
+ncclCommInitRankScalable
 ncclCommDeregister
 ncclCommRegister
 ncclCommSplit
@@ -42,6 +43,7 @@ pncclCommGetAsyncError
 pncclCommInitAll
 pncclCommInitRank
 pncclCommInitRankConfig
+pncclCommInitRankScalable
 pncclCommDeregister
 pncclCommRegister
 pncclCommSplit
diff --git a/xla/xla.proto b/xla/xla.proto
index 448cc49c9d9e7f..2b57c193ff14a2 100644
--- a/xla/xla.proto
+++ b/xla/xla.proto
@@ -858,6 +858,9 @@ message DebugOptions {
   // Enable NCCL per stream communicators.
   bool xla_gpu_enable_nccl_per_stream_comms = 276;
 
+  // Set number of ranks per root rank for NCCL init.
+  int64 xla_gpu_nccl_init_max_rank_per_root_ratio = 277;
+
   // If enabled, uses the libnvptxcompiler library to compile PTX to cuBIN.
   bool xla_gpu_enable_libnvptxcompiler = 269;
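
Illustrative sketch (not part of the patch): the standalone C++ snippet below mirrors
two pieces of arithmetic introduced above, the ceiling division that turns
xla_gpu_nccl_init_max_rank_per_root_ratio into a root count in gpu_clique_locking.cc,
and the balanced root placement in GpuCliqueKey::GetSubKey(). The helper names
NumRoots and RootDeviceIndex are invented for this demo; in the patch itself each
selected root device is the process that calls CreateUniqueCliqueId(), and the
resulting ids are handed together to ncclCommInitRankScalable.

#include <cstdint>
#include <cstdio>

// Ceiling division used to derive the number of roots from the flag value.
// A ratio of zero (or less) falls back to a single root, matching the flag's
// documented behavior.
int64_t NumRoots(int64_t nranks, int64_t max_rank_per_root_ratio) {
  if (max_rank_per_root_ratio <= 0) return 1;
  return (nranks + max_rank_per_root_ratio - 1) / max_rank_per_root_ratio;
}

// Same arithmetic as GpuCliqueKey::GetSubKey(): the first `nranks % nroots`
// roots own one extra rank, so root devices are spread evenly over the
// clique's device list.
int64_t RootDeviceIndex(int64_t nranks, int64_t nroots, int64_t root_seq_id) {
  int64_t rank_per_root = nranks / nroots;
  int64_t rank_remainder = nranks % nroots;
  if (root_seq_id < rank_remainder) {
    return root_seq_id * (rank_per_root + 1);
  }
  return rank_remainder * (rank_per_root + 1) +
         (root_seq_id - rank_remainder) * rank_per_root;
}

int main() {
  const int64_t nranks = 10;
  const int64_t ratio = 4;  // e.g. --xla_gpu_nccl_init_max_rank_per_root_ratio=4
  const int64_t nroots = NumRoots(nranks, ratio);  // ceil(10 / 4) == 3
  std::printf("nroots=%lld\n", static_cast<long long>(nroots));
  for (int64_t i = 0; i < nroots; ++i) {
    // Prints device indices 0, 4 and 7: rank groups of size 4, 3 and 3.
    std::printf("root %lld -> device index %lld\n", static_cast<long long>(i),
                static_cast<long long>(RootDeviceIndex(nranks, nroots, i)));
  }
  return 0;
}

With the default ratio of 128, a 1024-rank clique would use 8 roots instead of
funneling every bootstrap handshake through a single root rank, which is the
scaling improvement ncclCommInitRankScalable is meant to provide.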