diff --git a/xla/backends/gpu/collectives/gpu_clique_key.cc b/xla/backends/gpu/collectives/gpu_clique_key.cc index decd4d46b71d4..049afc36747d9 100644 --- a/xla/backends/gpu/collectives/gpu_clique_key.cc +++ b/xla/backends/gpu/collectives/gpu_clique_key.cc @@ -78,20 +78,24 @@ bool GpuCliqueKey::IsSubsetOf(const CliqueKey& other) const { }); } -GpuCliqueKey GpuCliqueKey::GetSubKey(int64_t nroots, int64_t root_seq_id) const { - CHECK(root_seq_id < nroots); - GpuCliqueKey subkey(*this); +std::vector GpuCliqueKey::GetSubKeys(int64_t nroots) const { const auto& devs = devices(); int64_t nranks = devs.size(); CHECK(nroots <= nranks); int64_t rank_per_root = nranks / nroots; - int64_t rank_remainder = nranks % nroots; - if (root_seq_id < rank_remainder) { - subkey.root_device_ = devs[root_seq_id * (rank_per_root + 1)]; - } else { - subkey.root_device_ = devs[rank_remainder * (rank_per_root + 1) + (root_seq_id - rank_remainder) * rank_per_root]; + int64_t rank_rem = nranks % nroots; + std::vector subkeys; + for (int64_t i = 0; i < nroots; ++i) { + GpuCliqueKey subkey(*this); + if (i < rank_rem) { + subkey.root_device_ = devs[i * (rank_per_root + 1)]; + } else { + subkey.root_device_ = devs[rank_rem * (rank_per_root + 1) + + (i - rank_rem) * rank_per_root]; + } + subkeys.push_back(subkey); } - return subkey; + return subkeys; } std::string GpuCliqueKey::ToString() const { diff --git a/xla/backends/gpu/collectives/gpu_clique_key.h b/xla/backends/gpu/collectives/gpu_clique_key.h index 1f19a177547f6..30d066ac0f277 100644 --- a/xla/backends/gpu/collectives/gpu_clique_key.h +++ b/xla/backends/gpu/collectives/gpu_clique_key.h @@ -75,9 +75,12 @@ class GpuCliqueKey : public CliqueKey { // same `stream_id` and all clique devices are part of `other` clique. bool IsSubsetOf(const CliqueKey& other) const final; - // Returns a copy of the key (subkey) with the root device properly set given - // nroots and root_seq_id. The subkey is used to generate a NcclCliqueId. 
- GpuCliqueKey GetSubKey(int64_t nroots, int64_t root_seq_id) const; + + // For multi-root initialization, generate `nroots` copies (subkeys) of the + // key each with a different root device. Root devices are distributed evenly + // across the ranks. The subkeys are used to exchange the CliqueIds during + // clique initialization. + std::vector GetSubKeys(int64_t nroots) const; // Returns the stream kind for this clique key, stream kind will be used to // specify what configuration to pass for each type of operation. diff --git a/xla/backends/gpu/collectives/gpu_cliques.cc b/xla/backends/gpu/collectives/gpu_cliques.cc index 3eb958a17a07e..670060f78dc9f 100644 --- a/xla/backends/gpu/collectives/gpu_cliques.cc +++ b/xla/backends/gpu/collectives/gpu_cliques.cc @@ -221,8 +221,8 @@ InitializeGpuClique(GpuCollectives* collectives, se::StreamExecutor* device, tsl::profiler::TraceMe trace("InitializeGpuClique"); CliqueIds clique_ids; - for (int64_t i = 0; i < nroots; ++i) { - GpuCliqueKey subkey = clique_key.GetSubKey(nroots, i); + const auto& subkeys = clique_key.GetSubKeys(nroots); + for (const auto& subkey : subkeys) { VLOG(3) << absl::StreamFormat( "Get CliqueId for sub clique key %s; nroots=%lld", subkey.ToString(), nroots); diff --git a/xla/backends/gpu/collectives/nccl_collectives.cc b/xla/backends/gpu/collectives/nccl_collectives.cc index 6e9572b4adb5e..b475ddf6fde69 100644 --- a/xla/backends/gpu/collectives/nccl_collectives.cc +++ b/xla/backends/gpu/collectives/nccl_collectives.cc @@ -133,8 +133,8 @@ NcclCollectives::CreateCommunicators(const CliqueKey& clique_key, const std::optional& clique_ids, absl::Span ranks, const Collectives::Config& config) { - // With NCCL backend we rely on host to exchange unique clique id. 
+ if (!clique_ids.has_value() || clique_ids->data().empty()) { return InvalidArgument("CliqueId is required to create NCCL communicators"); } @@ -154,7 +154,8 @@ NcclCollectives::CreateCommunicators(const CliqueKey& clique_key, for (size_t i = 0; i < ranks.size(); ++i) { VLOG(1) << "Initialize NCCL communicator for rank #" << ranks[i].rank << " of " << clique_key.num_devices() - << "; fingerprint(id)=" << clique_ids->fingerprint(); + << "; fingerprint(id)=" << clique_ids->fingerprint() + << "; size(id)=" << clique_ids->data().size(); TF_ASSIGN_OR_RETURN(auto* device, TryCast(ranks[i].device)); auto activate_context = device->stream_executor()->Activate(); diff --git a/xla/core/collectives/clique_id.h b/xla/core/collectives/clique_id.h index b129f73d974cf..18642cc835885 100644 --- a/xla/core/collectives/clique_id.h +++ b/xla/core/collectives/clique_id.h @@ -59,7 +59,8 @@ H AbslHashValue(H h, const CliqueId& id) { return H::combine(std::move(h), id.data()); } -// Collection of CliqueIds +// An evenly distributed list of root ranks (cliqueIds) to spread communication +// during clique setup. class CliqueIds { public: CliqueIds() = default;