Skip to content

Commit

Permalink
[XLA:GPU] Add support for NCCL ncclCommInitRankScalable API
Browse files Browse the repository at this point in the history
  • Loading branch information
nvcastet committed Jan 10, 2025
1 parent 08dcaad commit 98ef02d
Show file tree
Hide file tree
Showing 21 changed files with 161 additions and 39 deletions.
2 changes: 1 addition & 1 deletion xla/backends/cpu/collectives/gloo_collectives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ GlooCollectives::~GlooCollectives() = default;

absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
GlooCollectives::CreateCommunicators(const CliqueKey& clique_key,
const std::optional<CliqueId>& clique_id,
const std::optional<CliqueIds>& clique_ids,
absl::Span<const DeviceRank> ranks,
const Config& config) {
std::vector<std::unique_ptr<Communicator>> communicators;
Expand Down
2 changes: 1 addition & 1 deletion xla/backends/cpu/collectives/gloo_collectives.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class GlooCollectives : public CpuCollectives {

absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
CreateCommunicators(const CliqueKey& clique_key,
const std::optional<CliqueId>& clique_id,
const std::optional<CliqueIds>& clique_ids,
absl::Span<const DeviceRank> ranks,
const Config& config) final;

Expand Down
2 changes: 1 addition & 1 deletion xla/backends/cpu/collectives/in_process_collectives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ namespace xla::cpu {

absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
InProcessCollectives::CreateCommunicators(
const CliqueKey& clique_key, const std::optional<CliqueId>& clique_id,
const CliqueKey& clique_key, const std::optional<CliqueIds>& clique_ids,
absl::Span<const DeviceRank> ranks, const Config& config) {
std::vector<std::unique_ptr<Communicator>> communicators;
communicators.reserve(ranks.size());
Expand Down
2 changes: 1 addition & 1 deletion xla/backends/cpu/collectives/in_process_collectives.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class InProcessCollectives : public CpuCollectives {
public:
absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
CreateCommunicators(const CliqueKey& clique_key,
const std::optional<CliqueId>& clique_id,
const std::optional<CliqueIds>& clique_ids,
absl::Span<const DeviceRank> ranks,
const Config& config) final;
};
Expand Down
2 changes: 1 addition & 1 deletion xla/backends/cpu/collectives/mpi_collectives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ void MpiCollectives::Finalize() { MPI_Finalize(); }

absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
MpiCollectives::CreateCommunicators(const CliqueKey& clique_key,
const std::optional<CliqueId>& clique_id,
const std::optional<CliqueIds>& clique_ids,
absl::Span<const DeviceRank> ranks,
const Config& config) {
int flag;
Expand Down
2 changes: 1 addition & 1 deletion xla/backends/cpu/collectives/mpi_collectives.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class MpiCollectives : public CpuCollectives {

absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
CreateCommunicators(const CliqueKey& clique_key,
const std::optional<CliqueId>& clique_id,
const std::optional<CliqueIds>& clique_ids,
absl::Span<const DeviceRank> ranks,
const Config& config) final;

Expand Down
10 changes: 5 additions & 5 deletions xla/backends/gpu/collectives/gpu_clique.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ limitations under the License.
namespace xla::gpu {

GpuClique::GpuClique(
GpuCliqueKey key, std::optional<CliqueId> id,
GpuCliqueKey key, std::optional<CliqueIds> ids,
absl::btree_map<RankId, std::unique_ptr<Communicator>> communicators)
: Clique(std::move(communicators)), key_(key), id_(id) {}
: Clique(std::move(communicators)), key_(key), ids_(ids) {}

std::string GpuClique::DebugString() const {
std::string out =
absl::StrFormat("key: %s; fingerprint(id): %d; size: %d; communicators: ",
key_.ToString(), id_.has_value() ? id_->fingerprint() : 0,
key_.ToString(), ids_.has_value() ? ids_->fingerprint() : 0,
num_communicators());
int32_t cnt = 0;
ForEachComm([&](RankId rank, Communicator* comm) {
Expand All @@ -70,9 +70,9 @@ std::string GpuClique::LockableName::ToString(const GpuClique& clique) {
}

LockableGpuClique::LockableGpuClique(
GpuCliqueKey clique_key, std::optional<CliqueId> clique_id,
GpuCliqueKey clique_key, std::optional<CliqueIds> clique_ids,
absl::btree_map<RankId, std::unique_ptr<Communicator>> communicators)
: Lockable(std::move(clique_key), clique_id, std::move(communicators)) {}
: Lockable(std::move(clique_key), clique_ids, std::move(communicators)) {}

absl::Status LockableGpuClique::HealthCheck() const {
return value().HealthCheck();
Expand Down
8 changes: 4 additions & 4 deletions xla/backends/gpu/collectives/gpu_clique.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,15 @@ class LockableGpuClique;
class GpuClique : public Clique {
public:
GpuClique(
GpuCliqueKey key, std::optional<CliqueId> id,
GpuCliqueKey key, std::optional<CliqueIds> ids,
absl::btree_map<RankId, std::unique_ptr<Communicator>> communicators);

// Returns true if clique is local: all communicators belong to current
// process. Non-local cliques spans multiple processes (typically hosts).
bool IsLocal() const { return num_communicators() == key_.devices().size(); }

const GpuCliqueKey& key() const { return key_; }
const std::optional<CliqueId>& id() const { return id_; }
const std::optional<CliqueIds>& ids() const { return ids_; }

std::string DebugString() const final;
absl::Status HealthCheck() const final;
Expand All @@ -62,15 +62,15 @@ class GpuClique : public Clique {
};

GpuCliqueKey key_;
std::optional<CliqueId> id_;
std::optional<CliqueIds> ids_;
};

// A lockable version of GpuClique that guarantees exclusive access to the
// clique communicators.
class LockableGpuClique : public Lockable<GpuClique, GpuClique::LockableName> {
public:
LockableGpuClique(
GpuCliqueKey clique_key, std::optional<CliqueId> clique_id,
GpuCliqueKey clique_key, std::optional<CliqueIds> clique_ids,
absl::btree_map<RankId, std::unique_ptr<Communicator>> communicators);

std::string DebugString() const;
Expand Down
36 changes: 31 additions & 5 deletions xla/backends/gpu/collectives/gpu_clique_key.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ GpuCliqueKey::GpuCliqueKey(
: CliqueKey(std::move(devices)),
stream_id_(stream_id),
stream_kind_(stream_kind),
participant_groups_(std::move(participant_groups)) {
participant_groups_(std::move(participant_groups)),
root_device_(-1) {
for (std::vector<GlobalDeviceId>& group : participant_groups_) {
absl::c_sort(group);
}
Expand All @@ -65,6 +66,8 @@ GpuCliqueKey::GpuCliqueKey(

CollectiveStreamId GpuCliqueKey::stream_id() const { return stream_id_; }

GlobalDeviceId GpuCliqueKey::root_device() const { return root_device_; }

bool GpuCliqueKey::IsSubsetOf(const CliqueKey& other) const {
auto* other_nccl = tsl::down_cast<const GpuCliqueKey*>(&other);
if (other_nccl == nullptr) return false;
Expand All @@ -75,6 +78,22 @@ bool GpuCliqueKey::IsSubsetOf(const CliqueKey& other) const {
});
}

GpuCliqueKey GpuCliqueKey::GetSubKey(int64_t nroots, int64_t root_seq_id) const {
CHECK(root_seq_id < nroots);
GpuCliqueKey subkey(*this);
const auto& devs = devices();
int64_t nranks = devs.size();
CHECK(nroots <= nranks);
int64_t rank_per_root = nranks / nroots;
int64_t rank_remainder = nranks % nroots;
if (root_seq_id < rank_remainder) {
subkey.root_device_ = devs[root_seq_id * (rank_per_root + 1)];
} else {
subkey.root_device_ = devs[rank_remainder * (rank_per_root + 1) + (root_seq_id - rank_remainder) * rank_per_root];
}
return subkey;
}

std::string GpuCliqueKey::ToString() const {
std::string group_string = "";
if (!participant_groups_.empty()) {
Expand All @@ -85,19 +104,20 @@ std::string GpuCliqueKey::ToString() const {
}
group_string = absl::StrFormat("; groups=[%s]", absl::StrJoin(values, ","));
}
return absl::StrFormat("devices=[%s]; stream=%d%s",
return absl::StrFormat("devices=[%s]; stream=%d%s; root_device=%lld",
GlobalDeviceIdsToString(devices()), stream_id_.value(),
group_string);
group_string, root_device_.value());
}

void GpuCliqueKey::HashValue(absl::HashState state) const {
absl::HashState::combine(std::move(state), devices(), stream_id_,
participant_groups_);
participant_groups_, root_device_);
}

bool operator==(const GpuCliqueKey& a, const GpuCliqueKey& b) {
return a.devices() == b.devices() && a.stream_id_ == b.stream_id_ &&
a.participant_groups_ == b.participant_groups_;
a.participant_groups_ == b.participant_groups_ &&
a.root_device_ == b.root_device_;
}

bool operator<(const GpuCliqueKey& a, const GpuCliqueKey& b) {
Expand All @@ -107,6 +127,9 @@ bool operator<(const GpuCliqueKey& a, const GpuCliqueKey& b) {
if (a.devices() < b.devices()) return true;
if (b.devices() < a.devices()) return false;

if (a.root_device_ < b.root_device_) return true;
if (b.root_device_ < a.root_device_) return false;

return a.stream_id_.value() < b.stream_id_.value();
}

Expand All @@ -117,6 +140,9 @@ bool operator>(const GpuCliqueKey& a, const GpuCliqueKey& b) {
if (a.devices() > b.devices()) return true;
if (b.devices() > a.devices()) return false;

if (a.root_device_ > b.root_device_) return true;
if (b.root_device_ > a.root_device_) return false;

// We still use `<` to order by stream id as we want to acquire sync cliques
// before async ones.
return a.stream_id_.value() < b.stream_id_.value();
Expand Down
9 changes: 9 additions & 0 deletions xla/backends/gpu/collectives/gpu_clique_key.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,17 @@ class GpuCliqueKey : public CliqueKey {

CollectiveStreamId stream_id() const;

// Device generating the unique id for this key
GlobalDeviceId root_device() const;

// Returns true if this clique is a subset of `other`: both cliques have the
// same `stream_id` and all clique devices are part of `other` clique.
bool IsSubsetOf(const CliqueKey& other) const final;

// Returns a copy of the key (subkey) with the root device properly set given
// nroots and root_seq_id. The subkey is used to generate a NcclCliqueId.
GpuCliqueKey GetSubKey(int64_t nroots, int64_t root_seq_id) const;

// Returns the stream kind for this clique key, stream kind will be used to
// specify what configuration to pass for each type of operation.
AsyncStreamKind stream_kind() const { return stream_kind_; }
Expand Down Expand Up @@ -105,6 +112,8 @@ class GpuCliqueKey : public CliqueKey {
// Having the participating groups as part of the cache key will prevent such
// situations
std::vector<std::vector<GlobalDeviceId>> participant_groups_;

GlobalDeviceId root_device_;
};

bool operator==(const GpuCliqueKey& a, const GpuCliqueKey& b);
Expand Down
31 changes: 24 additions & 7 deletions xla/backends/gpu/collectives/gpu_cliques.cc
Original file line number Diff line number Diff line change
Expand Up @@ -205,13 +205,30 @@ InitializeGpuClique(GpuCollectives* collectives, se::StreamExecutor* device,

using RendezvousArg = std::pair<DeviceRank, /*synchronized=*/bool>;

// Check how many roots are needed to initialize the GpuClique
static const int64_t nccl_init_rank_per_root_ratio =
xla::GetDebugOptionsFromFlags().xla_gpu_nccl_init_max_rank_per_root_ratio();
int64_t nranks = clique_key.num_devices();
int64_t nroots = 1;
// Ceiling division to get number of roots
if (nccl_init_rank_per_root_ratio > 0) {
nroots = (nranks + nccl_init_rank_per_root_ratio - 1) / nccl_init_rank_per_root_ratio;
}
// Initializes a GpuClique for given device ranks and returns a lock that
// gives access to clique communicators.
auto initialize = [&](absl::Span<const RendezvousArg* const> args)
-> absl::StatusOr<LockableGpuClique::Lock> {
tsl::profiler::TraceMe trace("InitializeGpuClique");

TF_ASSIGN_OR_RETURN(auto clique_id, clique_id_callback(clique_key));
CliqueIds clique_ids;
for (int64_t i = 0; i < nroots; ++i) {
GpuCliqueKey subkey = clique_key.GetSubKey(nroots, i);
VLOG(3) << absl::StreamFormat(
"Get CliqueId for sub clique key %s; nroots=%lld",
subkey.ToString(), nroots);
TF_ASSIGN_OR_RETURN(auto clique_id, clique_id_callback(subkey));
clique_ids.Add(clique_id);
}

// Check that all ranks successfully synchronized device activity before
// trying to instantiate GPU communicators.
Expand All @@ -233,13 +250,13 @@ InitializeGpuClique(GpuCollectives* collectives, se::StreamExecutor* device,

VLOG(3) << absl::StreamFormat(
"Create GPU communicators for clique %s; ranks=[%s]; "
"fingerprint(id)=%d",
"nroots=%lld; fingerprint(id)=%d",
clique_key.ToString(), DeviceRanksToString(ranks),
clique_id.fingerprint());
nroots, clique_ids.fingerprint());

TF_ASSIGN_OR_RETURN(
std::vector<std::unique_ptr<Communicator>> created_comms,
collectives->CreateCommunicators(clique_key, clique_id, ranks, config));
collectives->CreateCommunicators(clique_key, clique_ids, ranks, config));

absl::btree_map<RankId, std::unique_ptr<Communicator>> comms;
for (size_t i = 0; i < ranks.size(); ++i) {
Expand All @@ -248,15 +265,15 @@ InitializeGpuClique(GpuCollectives* collectives, se::StreamExecutor* device,

VLOG(3) << absl::StreamFormat(
"Created GPU communicators for clique %s; ranks=[%s]; "
"fingerprint(id)=%d",
"nroots=%lld; fingerprint(id)=%d",
clique_key.ToString(), DeviceRanksToString(ranks),
clique_id.fingerprint());
nroots, clique_ids.fingerprint());

ProcessGpuCliques& cliques = GetProcessGpuCliques();
absl::MutexLock lock(&cliques.mu);

// Create a new clique with given clique key and communicators.
auto emplaced = cliques.map.try_emplace(clique_key, clique_key, clique_id,
auto emplaced = cliques.map.try_emplace(clique_key, clique_key, clique_ids,
std::move(comms));

// We can have a race to create a clique for a given key, the winner
Expand Down
2 changes: 1 addition & 1 deletion xla/backends/gpu/collectives/gpu_collectives_stub.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class GpuCollectivesStub : public GpuCollectives {
}

absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
CreateCommunicators(const CliqueKey&, const std::optional<CliqueId>&,
CreateCommunicators(const CliqueKey&, const std::optional<CliqueIds>&,
absl::Span<const DeviceRank>,
const Collectives::Config&) final {
return UnimplementedError();
Expand Down
31 changes: 23 additions & 8 deletions xla/backends/gpu/collectives/nccl_collectives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,18 +114,32 @@ static absl::StatusOr<ncclUniqueId> AsNcclUniqueId(const CliqueId& clique_id) {
return id;
}

static absl::StatusOr<std::vector<ncclUniqueId>> AsNcclUniqueIds(const CliqueIds& clique_ids) {
std::vector<ncclUniqueId> ids;
auto ids_vect = clique_ids.data();
ids.reserve(ids_vect.size());
for (const auto& clique_id : ids_vect) {
auto id = AsNcclUniqueId(clique_id);
if (!id.ok()) {
return id.status();
}
ids.push_back(id.value());
}
return ids;
}

absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
NcclCollectives::CreateCommunicators(const CliqueKey& clique_key,
const std::optional<CliqueId>& clique_id,
const std::optional<CliqueIds>& clique_ids,
absl::Span<const DeviceRank> ranks,
const Collectives::Config& config) {
// With NCCL backend we rely on host to exchange unique clique id.
if (!clique_id.has_value()) {
if (!clique_ids.has_value()) {
return InvalidArgument("CliqueId is required to create NCCL communicators");
}

VLOG(1) << "Initialize NCCL communicator for " << ranks.size() << " devices"
<< "; fingerprint(id)=" << clique_id->fingerprint();
<< "; fingerprint(id)=" << clique_ids->fingerprint();

TF_ASSIGN_OR_RETURN(auto* gpu_config, TryCast(&config));
ncclConfig_t comm_config = AsNcclConfig(*gpu_config);
Expand All @@ -140,14 +154,15 @@ NcclCollectives::CreateCommunicators(const CliqueKey& clique_key,
for (size_t i = 0; i < ranks.size(); ++i) {
VLOG(1) << "Initialize NCCL communicator for rank #" << ranks[i].rank
<< " of " << clique_key.num_devices()
<< "; fingerprint(id)=" << clique_id->fingerprint();
<< "; fingerprint(id)=" << clique_ids->fingerprint();
TF_ASSIGN_OR_RETURN(auto* device, TryCast(ranks[i].device));
auto activate_context = device->stream_executor()->Activate();

TF_ASSIGN_OR_RETURN(auto nccl_unique_id, AsNcclUniqueId(*clique_id));
XLA_NCCL_RETURN_IF_ERROR(ncclCommInitRankConfig(
&comm_handles[i], clique_key.num_devices(), nccl_unique_id,
ranks[i].rank.value(), &comm_config));
TF_ASSIGN_OR_RETURN(auto nccl_unique_ids, AsNcclUniqueIds(*clique_ids));
XLA_NCCL_RETURN_IF_ERROR(
ncclCommInitRankScalable(&comm_handles[i], clique_key.num_devices(),
ranks[i].rank.value(), nccl_unique_ids.size(),
nccl_unique_ids.data(), &comm_config));
}
TF_RETURN_IF_ERROR(GroupEnd());

Expand Down
2 changes: 1 addition & 1 deletion xla/backends/gpu/collectives/nccl_collectives.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class NcclCollectives : public GpuCollectives {

absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
CreateCommunicators(const CliqueKey& clique_key,
const std::optional<CliqueId>& clique_id,
const std::optional<CliqueIds>& clique_ids,
absl::Span<const DeviceRank> ranks,
const Collectives::Config& config) final;

Expand Down
16 changes: 16 additions & 0 deletions xla/core/collectives/clique_id.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,20 @@ uint32_t CliqueId::fingerprint() const {

size_t CliqueId::size() const { return data_.size(); }

CliqueIds::CliqueIds(const CliqueId& id) { Add(id); }

void CliqueIds::Add(const CliqueId& id) { ids_.push_back(id); }

std::vector<CliqueId> CliqueIds::data() const { return ids_; }

uint32_t CliqueIds::fingerprint() const {
absl::crc32c_t crc(0);
for (const auto& clique_id : ids_) {
crc = absl::ExtendCrc32c(crc, absl::string_view(clique_id.data().data(),
clique_id.data().size()));
}
return static_cast<uint32_t>(crc);
}


} // namespace xla
Loading

0 comments on commit 98ef02d

Please sign in to comment.