From 488ffae7c1ce37285ad77c9381c8bb0aae352645 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Thu, 9 Jan 2025 09:49:27 -0800 Subject: [PATCH] [xla:collectives] Remove redundant nranks argument from collectives API PiperOrigin-RevId: 713705217 --- xla/backends/cpu/collectives/cpu_cliques.cc | 7 +++---- xla/backends/cpu/collectives/gloo_collectives.cc | 4 +--- xla/backends/cpu/collectives/gloo_collectives.h | 2 +- .../cpu/collectives/gloo_collectives_test.cc | 8 ++++---- .../cpu/collectives/in_process_collectives.cc | 4 +--- .../cpu/collectives/in_process_collectives.h | 2 +- xla/backends/cpu/collectives/mpi_collectives.cc | 3 +-- xla/backends/cpu/collectives/mpi_collectives.h | 2 +- xla/backends/gpu/collectives/BUILD | 9 +++++---- xla/backends/gpu/collectives/gpu_clique_locking.cc | 12 +++++------- xla/backends/gpu/collectives/gpu_collectives_stub.h | 2 +- xla/backends/gpu/collectives/nccl_collectives.cc | 12 ++++++------ xla/backends/gpu/collectives/nccl_collectives.h | 2 +- xla/core/collectives/collectives.h | 2 +- 14 files changed, 32 insertions(+), 39 deletions(-) diff --git a/xla/backends/cpu/collectives/cpu_cliques.cc b/xla/backends/cpu/collectives/cpu_cliques.cc index 6e6c437256ad1..c52b400e4b579 100644 --- a/xla/backends/cpu/collectives/cpu_cliques.cc +++ b/xla/backends/cpu/collectives/cpu_cliques.cc @@ -99,10 +99,9 @@ absl::StatusOr AcquireCommunicator( CpuCollectives::DeviceRank device_rank(/*device=*/nullptr, rank); CpuCollectives::Config config; - TF_ASSIGN_OR_RETURN( - std::vector> communicators, - collectives->CreateCommunicators(clique_key.num_devices(), clique_key, - std::nullopt, {device_rank}, config)); + TF_ASSIGN_OR_RETURN(std::vector> communicators, + collectives->CreateCommunicators(clique_key, std::nullopt, + {device_rank}, config)); // We expect to create communicators lazily on at a time. if (communicators.size() != 1) { diff --git a/xla/backends/cpu/collectives/gloo_collectives.cc b/xla/backends/cpu/collectives/gloo_collectives.cc index 5880704f3c680..eb8705b81fd5f 100644 --- a/xla/backends/cpu/collectives/gloo_collectives.cc +++ b/xla/backends/cpu/collectives/gloo_collectives.cc @@ -16,7 +16,6 @@ limitations under the License. #include "xla/backends/cpu/collectives/gloo_collectives.h" #include -#include #include #include #include @@ -52,8 +51,7 @@ GlooCollectives::GlooCollectives( GlooCollectives::~GlooCollectives() = default; absl::StatusOr>> -GlooCollectives::CreateCommunicators(int32_t nranks, - const CliqueKey& clique_key, +GlooCollectives::CreateCommunicators(const CliqueKey& clique_key, const std::optional& clique_id, absl::Span ranks, const Config& config) { diff --git a/xla/backends/cpu/collectives/gloo_collectives.h b/xla/backends/cpu/collectives/gloo_collectives.h index 740e8ddc8bc21..9b52a05ea5e34 100644 --- a/xla/backends/cpu/collectives/gloo_collectives.h +++ b/xla/backends/cpu/collectives/gloo_collectives.h @@ -41,7 +41,7 @@ class GlooCollectives : public CpuCollectives { ~GlooCollectives() override; absl::StatusOr>> - CreateCommunicators(int32_t nranks, const CliqueKey& clique_key, + CreateCommunicators(const CliqueKey& clique_key, const std::optional& clique_id, absl::Span ranks, const Config& config) final; diff --git a/xla/backends/cpu/collectives/gloo_collectives_test.cc b/xla/backends/cpu/collectives/gloo_collectives_test.cc index 472327be12781..c4a9009e73c88 100644 --- a/xla/backends/cpu/collectives/gloo_collectives_test.cc +++ b/xla/backends/cpu/collectives/gloo_collectives_test.cc @@ -77,10 +77,10 @@ absl::StatusOr> GetCommunicator( CpuCliqueKey clique_key(global_devices); CpuCollectives::DeviceRank device_rank(nullptr, RankId(rank)); - TF_ASSIGN_OR_RETURN(auto communicators, - collectives->CreateCommunicators( - global_devices.size(), clique_key, std::nullopt, - {device_rank}, CpuCollectives::Config())); + TF_ASSIGN_OR_RETURN( + auto communicators, + collectives->CreateCommunicators(clique_key, std::nullopt, {device_rank}, + CpuCollectives::Config())); return std::move(communicators[0]); } diff --git a/xla/backends/cpu/collectives/in_process_collectives.cc b/xla/backends/cpu/collectives/in_process_collectives.cc index 80227ad7550cc..29bc7752e10e2 100644 --- a/xla/backends/cpu/collectives/in_process_collectives.cc +++ b/xla/backends/cpu/collectives/in_process_collectives.cc @@ -16,7 +16,6 @@ limitations under the License. #include "xla/backends/cpu/collectives/in_process_collectives.h" #include -#include #include #include #include @@ -35,8 +34,7 @@ namespace xla::cpu { absl::StatusOr>> InProcessCollectives::CreateCommunicators( - int32_t nranks, const CliqueKey& clique_key, - const std::optional& clique_id, + const CliqueKey& clique_key, const std::optional& clique_id, absl::Span ranks, const Config& config) { absl::MutexLock lock(&mu_); diff --git a/xla/backends/cpu/collectives/in_process_collectives.h b/xla/backends/cpu/collectives/in_process_collectives.h index 2fd5e53afcf32..11cd32f280ba9 100644 --- a/xla/backends/cpu/collectives/in_process_collectives.h +++ b/xla/backends/cpu/collectives/in_process_collectives.h @@ -37,7 +37,7 @@ namespace xla::cpu { class InProcessCollectives : public CpuCollectives { public: absl::StatusOr>> - CreateCommunicators(int32_t nranks, const CliqueKey& clique_key, + CreateCommunicators(const CliqueKey& clique_key, const std::optional& clique_id, absl::Span ranks, const Config& config) final; diff --git a/xla/backends/cpu/collectives/mpi_collectives.cc b/xla/backends/cpu/collectives/mpi_collectives.cc index 38b2dd1262b8d..c368ed986289f 100644 --- a/xla/backends/cpu/collectives/mpi_collectives.cc +++ b/xla/backends/cpu/collectives/mpi_collectives.cc @@ -16,7 +16,6 @@ limitations under the License. #include "xla/backends/cpu/collectives/mpi_collectives.h" #include -#include #include #include #include @@ -45,7 +44,7 @@ void MpiCollectives::Init() { void MpiCollectives::Finalize() { MPI_Finalize(); } absl::StatusOr>> -MpiCollectives::CreateCommunicators(int32_t nranks, const CliqueKey& clique_key, +MpiCollectives::CreateCommunicators(const CliqueKey& clique_key, const std::optional& clique_id, absl::Span ranks, const Config& config) { diff --git a/xla/backends/cpu/collectives/mpi_collectives.h b/xla/backends/cpu/collectives/mpi_collectives.h index 82722b954121a..702cb05fa4faf 100644 --- a/xla/backends/cpu/collectives/mpi_collectives.h +++ b/xla/backends/cpu/collectives/mpi_collectives.h @@ -48,7 +48,7 @@ class MpiCollectives : public CpuCollectives { void Finalize(); absl::StatusOr>> - CreateCommunicators(int32_t nranks, const CliqueKey& clique_key, + CreateCommunicators(const CliqueKey& clique_key, const std::optional& clique_id, absl::Span ranks, const Config& config) final; diff --git a/xla/backends/gpu/collectives/BUILD b/xla/backends/gpu/collectives/BUILD index 8177dbec9619c..5800d943f2bf6 100644 --- a/xla/backends/gpu/collectives/BUILD +++ b/xla/backends/gpu/collectives/BUILD @@ -106,6 +106,10 @@ cc_library( "//xla/service:lockable", "//xla/service:rendezvous", "//xla/stream_executor:stream_executor_h", + "//xla/tsl/platform:env", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:btree", @@ -119,11 +123,7 @@ cc_library( "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", - "@tsl//tsl/platform:env", - "@tsl//tsl/platform:errors", "@tsl//tsl/platform:hash", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:statusor", "@tsl//tsl/profiler/lib:traceme", ], ) @@ -214,6 +214,7 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@tsl//tsl/platform:casts", ] + if_cuda_is_configured([ diff --git a/xla/backends/gpu/collectives/gpu_clique_locking.cc b/xla/backends/gpu/collectives/gpu_clique_locking.cc index afee5ad405bbc..3181122e1227d 100644 --- a/xla/backends/gpu/collectives/gpu_clique_locking.cc +++ b/xla/backends/gpu/collectives/gpu_clique_locking.cc @@ -52,12 +52,12 @@ limitations under the License. #include "xla/service/lockable.h" #include "xla/service/rendezvous.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" #include "xla/util.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" #include "tsl/platform/hash.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" namespace xla::gpu { @@ -197,7 +197,6 @@ InitializeGpuClique(GpuCollectives* collectives, se::StreamExecutor* device, const GpuCollectives::CliqueIdCallback& clique_id_callback, int32_t num_local_participants, RankId rank, const GpuCollectives::Config& config) { - int nranks = clique_key.devices().size(); VLOG(3) << "Initialize GPU clique " << clique_key.ToString() << " rank #" << rank << "; num_local_participants=" << num_local_participants; @@ -240,8 +239,7 @@ InitializeGpuClique(GpuCollectives* collectives, se::StreamExecutor* device, TF_ASSIGN_OR_RETURN( std::vector> created_comms, - collectives->CreateCommunicators(nranks, clique_key, clique_id, ranks, - config)); + collectives->CreateCommunicators(clique_key, clique_id, ranks, config)); absl::btree_map> comms; for (size_t i = 0; i < ranks.size(); ++i) { diff --git a/xla/backends/gpu/collectives/gpu_collectives_stub.h b/xla/backends/gpu/collectives/gpu_collectives_stub.h index ad64b910c6c97..590d085450ee1 100644 --- a/xla/backends/gpu/collectives/gpu_collectives_stub.h +++ b/xla/backends/gpu/collectives/gpu_collectives_stub.h @@ -50,7 +50,7 @@ class GpuCollectivesStub : public GpuCollectives { } absl::StatusOr>> - CreateCommunicators(int32_t, const CliqueKey&, const std::optional&, + CreateCommunicators(const CliqueKey&, const std::optional&, absl::Span, const Collectives::Config&) final { return UnimplementedError(); diff --git a/xla/backends/gpu/collectives/nccl_collectives.cc b/xla/backends/gpu/collectives/nccl_collectives.cc index faa8caf48a6ec..59d0117c325c9 100644 --- a/xla/backends/gpu/collectives/nccl_collectives.cc +++ b/xla/backends/gpu/collectives/nccl_collectives.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/backends/gpu/collectives/gpu_collectives.h" #include "xla/backends/gpu/collectives/nccl_communicator.h" @@ -114,8 +115,7 @@ static absl::StatusOr AsNcclUniqueId(const CliqueId& clique_id) { } absl::StatusOr>> -NcclCollectives::CreateCommunicators(int32_t nranks, - const CliqueKey& clique_key, +NcclCollectives::CreateCommunicators(const CliqueKey& clique_key, const std::optional& clique_id, absl::Span ranks, const Collectives::Config& config) { @@ -139,15 +139,15 @@ NcclCollectives::CreateCommunicators(int32_t nranks, TF_RETURN_IF_ERROR(GroupStart()); for (size_t i = 0; i < ranks.size(); ++i) { VLOG(1) << "Initialize NCCL communicator for rank #" << ranks[i].rank - << " of " << nranks + << " of " << clique_key.num_devices() << "; fingerprint(id)=" << clique_id->fingerprint(); TF_ASSIGN_OR_RETURN(auto* device, TryCast(ranks[i].device)); auto activate_context = device->stream_executor()->Activate(); TF_ASSIGN_OR_RETURN(auto nccl_unique_id, AsNcclUniqueId(*clique_id)); - XLA_NCCL_RETURN_IF_ERROR( - ncclCommInitRankConfig(&comm_handles[i], nranks, nccl_unique_id, - ranks[i].rank.value(), &comm_config)); + XLA_NCCL_RETURN_IF_ERROR(ncclCommInitRankConfig( + &comm_handles[i], clique_key.num_devices(), nccl_unique_id, + ranks[i].rank.value(), &comm_config)); } TF_RETURN_IF_ERROR(GroupEnd()); diff --git a/xla/backends/gpu/collectives/nccl_collectives.h b/xla/backends/gpu/collectives/nccl_collectives.h index c8fb34f627635..721e94d0bc421 100644 --- a/xla/backends/gpu/collectives/nccl_collectives.h +++ b/xla/backends/gpu/collectives/nccl_collectives.h @@ -49,7 +49,7 @@ class NcclCollectives : public GpuCollectives { absl::Status GroupEnd() final; absl::StatusOr>> - CreateCommunicators(int32_t nranks, const CliqueKey& clique_key, + CreateCommunicators(const CliqueKey& clique_key, const std::optional& clique_id, absl::Span ranks, const Collectives::Config& config) final; diff --git a/xla/core/collectives/collectives.h b/xla/core/collectives/collectives.h index 4b41a0dd44081..68f061252b94c 100644 --- a/xla/core/collectives/collectives.h +++ b/xla/core/collectives/collectives.h @@ -70,7 +70,7 @@ class Collectives { // Creates communicators for given clique key and id. virtual absl::StatusOr>> - CreateCommunicators(int32_t nranks, const CliqueKey& clique_key, + CreateCommunicators(const CliqueKey& clique_key, const std::optional& clique_id, absl::Span ranks, const Config& config) = 0;