Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
nvcastet committed Jan 10, 2025
1 parent 98ef02d commit 06fcca9
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 18 deletions.
22 changes: 13 additions & 9 deletions xla/backends/gpu/collectives/gpu_clique_key.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,20 +78,24 @@ bool GpuCliqueKey::IsSubsetOf(const CliqueKey& other) const {
});
}

GpuCliqueKey GpuCliqueKey::GetSubKey(int64_t nroots, int64_t root_seq_id) const {
CHECK(root_seq_id < nroots);
GpuCliqueKey subkey(*this);
absl::Span<const GpuCliqueKey> GpuCliqueKey::GetSubKeys(int64_t nroots) const {
const auto& devs = devices();
int64_t nranks = devs.size();
CHECK(nroots <= nranks);
int64_t rank_per_root = nranks / nroots;
int64_t rank_remainder = nranks % nroots;
if (root_seq_id < rank_remainder) {
subkey.root_device_ = devs[root_seq_id * (rank_per_root + 1)];
} else {
subkey.root_device_ = devs[rank_remainder * (rank_per_root + 1) + (root_seq_id - rank_remainder) * rank_per_root];
int64_t rank_rem = nranks % nroots;
std::vector<GpuCliqueKey> subkeys;
for (int64_t i = 0; i < nroots; ++i) {
GpuCliqueKey subkey(*this);
if (i < rank_rem) {
subkey.root_device_ = devs[i * (rank_per_root + 1)];
} else {
subkey.root_device_ = devs[rank_rem * (rank_per_root + 1)
+ (i - rank_rem) * rank_per_root];
}
subkeys.push_back(subkey);
}
return subkey;
return subkeys;
}

std::string GpuCliqueKey::ToString() const {
Expand Down
9 changes: 6 additions & 3 deletions xla/backends/gpu/collectives/gpu_clique_key.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,12 @@ class GpuCliqueKey : public CliqueKey {
// same `stream_id` and all clique devices are part of `other` clique.
bool IsSubsetOf(const CliqueKey& other) const final;

// Returns a copy of the key (subkey) with the root device properly set given
// nroots and root_seq_id. The subkey is used to generate a NcclCliqueId.
GpuCliqueKey GetSubKey(int64_t nroots, int64_t root_seq_id) const;

// For multi-root initialization, generate `nroots` copies (subkeys) of the
// key each with a different root device. Root devices are distributed evenly
// accross the ranks. The subkeys are used to exchange the CliqueIds during
// clique initialization.
absl::Span<const GpuCliqueKey> GetSubKeys(int64_t nroots) const;

// Returns the stream kind for this clique key, stream kind will be used to
// specify what configuration to pass for each type of operation.
Expand Down
4 changes: 2 additions & 2 deletions xla/backends/gpu/collectives/gpu_cliques.cc
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,8 @@ InitializeGpuClique(GpuCollectives* collectives, se::StreamExecutor* device,
tsl::profiler::TraceMe trace("InitializeGpuClique");

CliqueIds clique_ids;
for (int64_t i = 0; i < nroots; ++i) {
GpuCliqueKey subkey = clique_key.GetSubKey(nroots, i);
const auto& subkeys = clique_key.GetSubKeys(nroots);
for (const auto& subkey : subkeys) {
VLOG(3) << absl::StreamFormat(
"Get CliqueId for sub clique key %s; nroots=%lld",
subkey.ToString(), nroots);
Expand Down
7 changes: 4 additions & 3 deletions xla/backends/gpu/collectives/nccl_collectives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,8 @@ NcclCollectives::CreateCommunicators(const CliqueKey& clique_key,
const std::optional<CliqueIds>& clique_ids,
absl::Span<const DeviceRank> ranks,
const Collectives::Config& config) {
// With NCCL backend we rely on host to exchange unique clique id.
if (!clique_ids.has_value()) {
// With NCCL backend we rely on host to exchange unique clique ids.
if (!clique_ids.has_value() || clique_ids->data().empty()) {
return InvalidArgument("CliqueId is required to create NCCL communicators");
}

Expand All @@ -154,7 +154,8 @@ NcclCollectives::CreateCommunicators(const CliqueKey& clique_key,
for (size_t i = 0; i < ranks.size(); ++i) {
VLOG(1) << "Initialize NCCL communicator for rank #" << ranks[i].rank
<< " of " << clique_key.num_devices()
<< "; fingerprint(id)=" << clique_ids->fingerprint();
<< "; fingerprint(id)=" << clique_ids->fingerprint()
<< "; size(id)=" << clique_ids->data().size();
TF_ASSIGN_OR_RETURN(auto* device, TryCast(ranks[i].device));
auto activate_context = device->stream_executor()->Activate();

Expand Down
3 changes: 2 additions & 1 deletion xla/core/collectives/clique_id.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ H AbslHashValue(H h, const CliqueId& id) {
return H::combine(std::move(h), id.data());
}

// Collection of CliqueIds
// An evenly distributed list of root ranks (cliqueIds) to spread communication
// during clique setup.
class CliqueIds {
public:
CliqueIds() = default;
Expand Down

0 comments on commit 06fcca9

Please sign in to comment.