diff --git a/xla/backends/cpu/collectives/BUILD b/xla/backends/cpu/collectives/BUILD index 08faed4fae0a7..301bbc087b0bd 100644 --- a/xla/backends/cpu/collectives/BUILD +++ b/xla/backends/cpu/collectives/BUILD @@ -144,6 +144,7 @@ cc_library( "//xla/service:global_device_id", "//xla/service:rendezvous", "//xla/stream_executor:device_memory", + "//xla/tsl/platform:errors", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", @@ -153,7 +154,6 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", - "@tsl//tsl/platform:errors", ], ) diff --git a/xla/backends/cpu/collectives/in_process_communicator.cc b/xla/backends/cpu/collectives/in_process_communicator.cc index 10d177db2d144..4b856593311c4 100644 --- a/xla/backends/cpu/collectives/in_process_communicator.cc +++ b/xla/backends/cpu/collectives/in_process_communicator.cc @@ -43,6 +43,7 @@ limitations under the License. 
#include "xla/service/global_device_id.h" #include "xla/service/rendezvous.h" #include "xla/stream_executor/device_memory.h" +#include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -59,113 +60,6 @@ void FormatGlobalId(std::string* out, const GlobalDeviceId& device) { absl::StrAppend(out, device.value()); } -//===----------------------------------------------------------------------===// -// AllGather -//===----------------------------------------------------------------------===// - -struct AllGatherParticipant { - size_t rank; - se::DeviceMemoryBase src; - se::DeviceMemoryBase dest; -}; - -static absl::Status AllGatherOp( - size_t num_bytes, absl::Span participants) { - absl::c_sort(participants, ByRank); - - size_t num_participants = participants.size(); - - for (size_t i = 0; i < num_participants; ++i) { - for (size_t j = 0; j < num_participants; ++j) { - std::byte* dest = static_cast(participants[i]->dest.opaque()); - size_t offset = j * num_bytes; - std::memcpy(dest + offset, participants[j]->src.opaque(), num_bytes); - } - } - - return absl::OkStatus(); -} - -//===----------------------------------------------------------------------===// -// AllToAll -//===----------------------------------------------------------------------===// - -struct AllToAllParticipant { - size_t rank; - - std::vector src; - std::vector dest; -}; - -static absl::Status AllToAllOp( - size_t num_bytes, absl::Span participants) { - absl::c_sort(participants, ByRank); - - size_t num_participants = participants.size(); - - for (size_t i = 0; i < num_participants; ++i) { - for (size_t j = 0; j < num_participants; ++j) { - std::memcpy(participants[j]->dest[i].opaque(), - participants[i]->src[j].opaque(), num_bytes); - } - } - - return absl::OkStatus(); -} - -//===----------------------------------------------------------------------===// -// CollectivePermute 
-//===----------------------------------------------------------------------===// - -struct CollectivePermuteParticipant { - size_t rank; - std::optional src_rank; - - se::DeviceMemoryBase src; - se::DeviceMemoryBase dest; -}; - -static absl::Status CollectivePermuteOp( - size_t num_bytes, - absl::Span participants) { - absl::c_sort(participants, ByRank); - - for (const CollectivePermuteParticipant* participant : participants) { - void* dest = participant->dest.opaque(); - - if (participant->src_rank) { - size_t src_rank = participant->src_rank->value(); - std::memcpy(dest, participants.at(src_rank)->src.opaque(), num_bytes); - } else { - std::memset(dest, 0, num_bytes); - } - } - return absl::OkStatus(); -} - -//===----------------------------------------------------------------------===// - -struct AllReduceParticipantData : ParticipantData { - explicit AllReduceParticipantData(const RendezvousKey& rendezvous_key_p, - int rank) - : ParticipantData(rendezvous_key_p, rank) {} - - int64_t element_count; - const void* source_data; - void* destination_data; - PrimitiveType primitive_type; - - ReductionKind reduction_kind; - - std::string ToString() const override { - return absl::StrFormat( - "AllReduceParticipantData{rank=%d, element_count=%d, type=%s, " - "rendezvous_key=%s}", - local_rank, element_count, PrimitiveType_Name(primitive_type), - rendezvous_key.ToString()); - } -}; - template T GetInitialValue(ReductionKind reduction_kind) { switch (reduction_kind) { @@ -266,65 +160,136 @@ absl::Status ReduceScatter(ReductionKind reduction_kind, return absl::OkStatus(); } -class CpuAllReduceRendezvous - : public Rendezvous { - public: - explicit CpuAllReduceRendezvous(const RendezvousKey& k) - : Rendezvous(k) {} +//===----------------------------------------------------------------------===// +// AllReduce +//===----------------------------------------------------------------------===// - protected: - absl::StatusOr RunCollectiveOp( - const AllReduceParticipantData& me) 
override { - VLOG(3) << me.ToString(); - int64_t world_size = participants_.size(); - // Divide the buffer up into equal(ish) chunks. Rank r computes the r-th - // chunk of the output. - int64_t chunk_elems = CeilOfRatio(me.element_count, world_size); - - int64_t start_elem = me.local_rank * chunk_elems; - int64_t end_elem = std::min(start_elem + chunk_elems, me.element_count); - chunk_elems = std::max(int64_t{0}, end_elem - start_elem); - if (chunk_elems == 0) { - return nullptr; - } +struct AllReduceParticipant { + size_t rank; + se::DeviceMemoryBase src; + se::DeviceMemoryBase dest; +}; - auto bytes_per_elem = primitive_util::ByteWidth(me.primitive_type); - int64_t chunk_offset = start_elem * bytes_per_elem; - int64_t chunk_bytes = chunk_elems * bytes_per_elem; - void* reduce_output = - reinterpret_cast(me.destination_data) + chunk_offset; +static absl::Status AllReduceOp( + PrimitiveType primitive_type, size_t count, ReductionKind reduction_kind, + absl::Span participants) { + absl::c_sort(participants, ByRank); - std::vector inputs; - inputs.reserve(world_size); - for (const auto& p : participants_) { - inputs.push_back(reinterpret_cast(p->source_data) + - chunk_offset); - } + if (!primitive_util::IsArrayType(primitive_type)) { + return Unimplemented( + "Unexpected datatype: %s", + primitive_util::LowercasePrimitiveTypeName(primitive_type)); + } - if (primitive_util::IsArrayType(me.primitive_type)) { - TF_RETURN_IF_ERROR(primitive_util::ArrayTypeSwitch( - [&](const auto constant_type) { - return ReduceScatter(me.reduction_kind, inputs, - reduce_output, chunk_elems); - }, - me.primitive_type)); - } else { - return absl::UnimplementedError(absl::StrCat( - "Unexpected datatype: ", - primitive_util::LowercasePrimitiveTypeName(me.primitive_type))); + // Reduce all inputs into a single output at rank 0. 
+ std::vector inputs(participants.size()); + for (auto* participant : participants) { + inputs[participant->rank] = participant->src.opaque(); + } + void* output = participants[0]->dest.opaque(); + + TF_RETURN_IF_ERROR(primitive_util::ArrayTypeSwitch( + [&](const auto constant_type) { + return ReduceScatter(reduction_kind, inputs, output, + count); + }, + primitive_type)); + + // Copy all-reduced output to all other participants. + for (size_t i = 1; i < participants.size(); ++i) { + std::memcpy(participants[i]->dest.opaque(), participants[0]->dest.opaque(), + count * primitive_util::ByteWidth(primitive_type)); + } + + return absl::OkStatus(); +} + +//===----------------------------------------------------------------------===// +// AllGather +//===----------------------------------------------------------------------===// + +struct AllGatherParticipant { + size_t rank; + se::DeviceMemoryBase src; + se::DeviceMemoryBase dest; +}; + +static absl::Status AllGatherOp( + size_t num_bytes, absl::Span participants) { + absl::c_sort(participants, ByRank); + + size_t num_participants = participants.size(); + + for (size_t i = 0; i < num_participants; ++i) { + for (size_t j = 0; j < num_participants; ++j) { + std::byte* dest = static_cast(participants[i]->dest.opaque()); + size_t offset = j * num_bytes; + std::memcpy(dest + offset, participants[j]->src.opaque(), num_bytes); } + } - // All-gather the reduced chunks. 
- for (const auto& p : participants_) { - if (p->local_rank != me.local_rank) { - std::memcpy(reinterpret_cast(p->destination_data) + chunk_offset, - reduce_output, chunk_bytes); - } + return absl::OkStatus(); +} + +//===----------------------------------------------------------------------===// +// AllToAll +//===----------------------------------------------------------------------===// + +struct AllToAllParticipant { + size_t rank; + + std::vector src; + std::vector dest; +}; + +static absl::Status AllToAllOp( + size_t num_bytes, absl::Span participants) { + absl::c_sort(participants, ByRank); + + size_t num_participants = participants.size(); + + for (size_t i = 0; i < num_participants; ++i) { + for (size_t j = 0; j < num_participants; ++j) { + std::memcpy(participants[j]->dest[i].opaque(), + participants[i]->src[j].opaque(), num_bytes); } - return nullptr; } + + return absl::OkStatus(); +} + +//===----------------------------------------------------------------------===// +// CollectivePermute +//===----------------------------------------------------------------------===// + +struct CollectivePermuteParticipant { + size_t rank; + std::optional src_rank; + + se::DeviceMemoryBase src; + se::DeviceMemoryBase dest; }; +static absl::Status CollectivePermuteOp( + size_t num_bytes, + absl::Span participants) { + absl::c_sort(participants, ByRank); + + for (const CollectivePermuteParticipant* participant : participants) { + void* dest = participant->dest.opaque(); + + if (participant->src_rank) { + size_t src_rank = participant->src_rank->value(); + std::memcpy(dest, participants.at(src_rank)->src.opaque(), num_bytes); + } else { + std::memset(dest, 0, num_bytes); + } + } + return absl::OkStatus(); +} + +//===----------------------------------------------------------------------===// + struct ReduceScatterParticipantData : ParticipantData { ReduceScatterParticipantData(const RendezvousKey& rendezvous_key_p, int rank) : ParticipantData(rendezvous_key_p, rank) {} @@ 
-385,8 +350,6 @@ class CpuReduceScatterRendezvous } // namespace struct InProcessCommunicator::State { - RefcountingHashMap - all_reduce_rendezvous_map; RefcountingHashMap reduce_scatter_rendezvous_map; }; @@ -410,24 +373,13 @@ absl::Status InProcessCommunicator::AllReduce(se::DeviceMemoryBase send_buffer, TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); const RendezvousKey& key = cpu_executor->rendezvous_key(); - AllReduceParticipantData participant(key, rank_); - participant.element_count = count; - participant.primitive_type = dtype; - participant.source_data = send_buffer.opaque(); - participant.destination_data = recv_buffer.opaque(); - participant.reduction_kind = reduction_kind; + std::string name = absl::StrCat("all reduce ", key.ToString()); + AllReduceParticipant participant{rank_, send_buffer, recv_buffer}; - auto make_cpu_rendezvous = [](const RendezvousKey& k) { - return std::make_unique(k); - }; - - return CpuAllReduceRendezvous::SubmitParticipant( - [&] { - return state_->all_reduce_rendezvous_map.GetOrCreateIfAbsent( - key, make_cpu_rendezvous); - }, - participant) - .status(); + return RendezvousSingle( + name, key, participant, key.num_local_participants, + std::bind(AllReduceOp, dtype, count, reduction_kind, + std::placeholders::_1)); } absl::Status InProcessCommunicator::CollectivePermute(