Skip to content

Commit

Permalink
[xla:collectives] Remove redundant nranks argument from collectives API
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 713705217
  • Loading branch information
ezhulenev authored and Google-ML-Automation committed Jan 9, 2025
1 parent 2f6eabb commit 488ffae
Show file tree
Hide file tree
Showing 14 changed files with 32 additions and 39 deletions.
7 changes: 3 additions & 4 deletions xla/backends/cpu/collectives/cpu_cliques.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,9 @@ absl::StatusOr<Communicator*> AcquireCommunicator(
CpuCollectives::DeviceRank device_rank(/*device=*/nullptr, rank);
CpuCollectives::Config config;

TF_ASSIGN_OR_RETURN(
std::vector<std::unique_ptr<Communicator>> communicators,
collectives->CreateCommunicators(clique_key.num_devices(), clique_key,
std::nullopt, {device_rank}, config));
TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<Communicator>> communicators,
collectives->CreateCommunicators(clique_key, std::nullopt,
{device_rank}, config));

// We expect to create communicators lazily on at a time.
if (communicators.size() != 1) {
Expand Down
4 changes: 1 addition & 3 deletions xla/backends/cpu/collectives/gloo_collectives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ limitations under the License.
#include "xla/backends/cpu/collectives/gloo_collectives.h"

#include <cstddef>
#include <cstdint>
#include <exception>
#include <memory>
#include <optional>
Expand Down Expand Up @@ -52,8 +51,7 @@ GlooCollectives::GlooCollectives(
GlooCollectives::~GlooCollectives() = default;

absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
GlooCollectives::CreateCommunicators(int32_t nranks,
const CliqueKey& clique_key,
GlooCollectives::CreateCommunicators(const CliqueKey& clique_key,
const std::optional<CliqueId>& clique_id,
absl::Span<const DeviceRank> ranks,
const Config& config) {
Expand Down
2 changes: 1 addition & 1 deletion xla/backends/cpu/collectives/gloo_collectives.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class GlooCollectives : public CpuCollectives {
~GlooCollectives() override;

absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
CreateCommunicators(int32_t nranks, const CliqueKey& clique_key,
CreateCommunicators(const CliqueKey& clique_key,
const std::optional<CliqueId>& clique_id,
absl::Span<const DeviceRank> ranks,
const Config& config) final;
Expand Down
8 changes: 4 additions & 4 deletions xla/backends/cpu/collectives/gloo_collectives_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,10 @@ absl::StatusOr<std::unique_ptr<Communicator>> GetCommunicator(
CpuCliqueKey clique_key(global_devices);
CpuCollectives::DeviceRank device_rank(nullptr, RankId(rank));

TF_ASSIGN_OR_RETURN(auto communicators,
collectives->CreateCommunicators(
global_devices.size(), clique_key, std::nullopt,
{device_rank}, CpuCollectives::Config()));
TF_ASSIGN_OR_RETURN(
auto communicators,
collectives->CreateCommunicators(clique_key, std::nullopt, {device_rank},
CpuCollectives::Config()));

return std::move(communicators[0]);
}
Expand Down
4 changes: 1 addition & 3 deletions xla/backends/cpu/collectives/in_process_collectives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ limitations under the License.
#include "xla/backends/cpu/collectives/in_process_collectives.h"

#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <vector>
Expand All @@ -35,8 +34,7 @@ namespace xla::cpu {

absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
InProcessCollectives::CreateCommunicators(
int32_t nranks, const CliqueKey& clique_key,
const std::optional<CliqueId>& clique_id,
const CliqueKey& clique_key, const std::optional<CliqueId>& clique_id,
absl::Span<const DeviceRank> ranks, const Config& config) {
absl::MutexLock lock(&mu_);

Expand Down
2 changes: 1 addition & 1 deletion xla/backends/cpu/collectives/in_process_collectives.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ namespace xla::cpu {
class InProcessCollectives : public CpuCollectives {
public:
absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
CreateCommunicators(int32_t nranks, const CliqueKey& clique_key,
CreateCommunicators(const CliqueKey& clique_key,
const std::optional<CliqueId>& clique_id,
absl::Span<const DeviceRank> ranks,
const Config& config) final;
Expand Down
3 changes: 1 addition & 2 deletions xla/backends/cpu/collectives/mpi_collectives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ limitations under the License.
#include "xla/backends/cpu/collectives/mpi_collectives.h"

#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <vector>
Expand Down Expand Up @@ -45,7 +44,7 @@ void MpiCollectives::Init() {
void MpiCollectives::Finalize() { MPI_Finalize(); }

absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
MpiCollectives::CreateCommunicators(int32_t nranks, const CliqueKey& clique_key,
MpiCollectives::CreateCommunicators(const CliqueKey& clique_key,
const std::optional<CliqueId>& clique_id,
absl::Span<const DeviceRank> ranks,
const Config& config) {
Expand Down
2 changes: 1 addition & 1 deletion xla/backends/cpu/collectives/mpi_collectives.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class MpiCollectives : public CpuCollectives {
void Finalize();

absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
CreateCommunicators(int32_t nranks, const CliqueKey& clique_key,
CreateCommunicators(const CliqueKey& clique_key,
const std::optional<CliqueId>& clique_id,
absl::Span<const DeviceRank> ranks,
const Config& config) final;
Expand Down
9 changes: 5 additions & 4 deletions xla/backends/gpu/collectives/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ cc_library(
"//xla/service:lockable",
"//xla/service:rendezvous",
"//xla/stream_executor:stream_executor_h",
"//xla/tsl/platform:env",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:logging",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:btree",
Expand All @@ -119,11 +123,7 @@ cc_library(
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:hash",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/profiler/lib:traceme",
],
)
Expand Down Expand Up @@ -214,6 +214,7 @@ cc_library(
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:casts",
] + if_cuda_is_configured([
Expand Down
12 changes: 5 additions & 7 deletions xla/backends/gpu/collectives/gpu_clique_locking.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,12 @@ limitations under the License.
#include "xla/service/lockable.h"
#include "xla/service/rendezvous.h"
#include "xla/stream_executor/stream_executor.h"
#include "xla/tsl/platform/env.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/logging.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/util.h"
#include "tsl/platform/env.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/hash.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/statusor.h"
#include "tsl/profiler/lib/traceme.h"

namespace xla::gpu {
Expand Down Expand Up @@ -197,7 +197,6 @@ InitializeGpuClique(GpuCollectives* collectives, se::StreamExecutor* device,
const GpuCollectives::CliqueIdCallback& clique_id_callback,
int32_t num_local_participants, RankId rank,
const GpuCollectives::Config& config) {
int nranks = clique_key.devices().size();
VLOG(3) << "Initialize GPU clique " << clique_key.ToString() << " rank #"
<< rank << "; num_local_participants=" << num_local_participants;

Expand Down Expand Up @@ -240,8 +239,7 @@ InitializeGpuClique(GpuCollectives* collectives, se::StreamExecutor* device,

TF_ASSIGN_OR_RETURN(
std::vector<std::unique_ptr<Communicator>> created_comms,
collectives->CreateCommunicators(nranks, clique_key, clique_id, ranks,
config));
collectives->CreateCommunicators(clique_key, clique_id, ranks, config));

absl::btree_map<RankId, std::unique_ptr<Communicator>> comms;
for (size_t i = 0; i < ranks.size(); ++i) {
Expand Down
2 changes: 1 addition & 1 deletion xla/backends/gpu/collectives/gpu_collectives_stub.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class GpuCollectivesStub : public GpuCollectives {
}

absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
CreateCommunicators(int32_t, const CliqueKey&, const std::optional<CliqueId>&,
CreateCommunicators(const CliqueKey&, const std::optional<CliqueId>&,
absl::Span<const DeviceRank>,
const Collectives::Config&) final {
return UnimplementedError();
Expand Down
12 changes: 6 additions & 6 deletions xla/backends/gpu/collectives/nccl_collectives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/backends/gpu/collectives/gpu_collectives.h"
#include "xla/backends/gpu/collectives/nccl_communicator.h"
Expand Down Expand Up @@ -114,8 +115,7 @@ static absl::StatusOr<ncclUniqueId> AsNcclUniqueId(const CliqueId& clique_id) {
}

absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
NcclCollectives::CreateCommunicators(int32_t nranks,
const CliqueKey& clique_key,
NcclCollectives::CreateCommunicators(const CliqueKey& clique_key,
const std::optional<CliqueId>& clique_id,
absl::Span<const DeviceRank> ranks,
const Collectives::Config& config) {
Expand All @@ -139,15 +139,15 @@ NcclCollectives::CreateCommunicators(int32_t nranks,
TF_RETURN_IF_ERROR(GroupStart());
for (size_t i = 0; i < ranks.size(); ++i) {
VLOG(1) << "Initialize NCCL communicator for rank #" << ranks[i].rank
<< " of " << nranks
<< " of " << clique_key.num_devices()
<< "; fingerprint(id)=" << clique_id->fingerprint();
TF_ASSIGN_OR_RETURN(auto* device, TryCast(ranks[i].device));
auto activate_context = device->stream_executor()->Activate();

TF_ASSIGN_OR_RETURN(auto nccl_unique_id, AsNcclUniqueId(*clique_id));
XLA_NCCL_RETURN_IF_ERROR(
ncclCommInitRankConfig(&comm_handles[i], nranks, nccl_unique_id,
ranks[i].rank.value(), &comm_config));
XLA_NCCL_RETURN_IF_ERROR(ncclCommInitRankConfig(
&comm_handles[i], clique_key.num_devices(), nccl_unique_id,
ranks[i].rank.value(), &comm_config));
}
TF_RETURN_IF_ERROR(GroupEnd());

Expand Down
2 changes: 1 addition & 1 deletion xla/backends/gpu/collectives/nccl_collectives.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class NcclCollectives : public GpuCollectives {
absl::Status GroupEnd() final;

absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
CreateCommunicators(int32_t nranks, const CliqueKey& clique_key,
CreateCommunicators(const CliqueKey& clique_key,
const std::optional<CliqueId>& clique_id,
absl::Span<const DeviceRank> ranks,
const Collectives::Config& config) final;
Expand Down
2 changes: 1 addition & 1 deletion xla/core/collectives/collectives.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class Collectives {

// Creates communicators for given clique key and id.
virtual absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
CreateCommunicators(int32_t nranks, const CliqueKey& clique_key,
CreateCommunicators(const CliqueKey& clique_key,
const std::optional<CliqueId>& clique_id,
absl::Span<const DeviceRank> ranks,
const Config& config) = 0;
Expand Down

0 comments on commit 488ffae

Please sign in to comment.