Skip to content

Commit

Permalink
[xla:cpu] Migrate AllReduce to RendezvousSingle API
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 713821152
  • Loading branch information
ezhulenev authored and Google-ML-Automation committed Jan 9, 2025
1 parent 5a6ef8a commit 38f3209
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 176 deletions.
2 changes: 1 addition & 1 deletion xla/backends/cpu/collectives/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ cc_library(
"//xla/service:global_device_id",
"//xla/service:rendezvous",
"//xla/stream_executor:device_memory",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/log",
Expand All @@ -153,7 +154,6 @@ cc_library(
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:errors",
],
)

Expand Down
302 changes: 127 additions & 175 deletions xla/backends/cpu/collectives/in_process_communicator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ limitations under the License.
#include "xla/service/global_device_id.h"
#include "xla/service/rendezvous.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
Expand All @@ -59,113 +60,6 @@ void FormatGlobalId(std::string* out, const GlobalDeviceId& device) {
absl::StrAppend(out, device.value());
}

//===----------------------------------------------------------------------===//
// AllGather
//===----------------------------------------------------------------------===//

struct AllGatherParticipant {
size_t rank;
se::DeviceMemoryBase src;
se::DeviceMemoryBase dest;
};

static absl::Status AllGatherOp(
size_t num_bytes, absl::Span<const AllGatherParticipant*> participants) {
absl::c_sort(participants, ByRank<AllGatherParticipant>);

size_t num_participants = participants.size();

for (size_t i = 0; i < num_participants; ++i) {
for (size_t j = 0; j < num_participants; ++j) {
std::byte* dest = static_cast<std::byte*>(participants[i]->dest.opaque());
size_t offset = j * num_bytes;
std::memcpy(dest + offset, participants[j]->src.opaque(), num_bytes);
}
}

return absl::OkStatus();
}

//===----------------------------------------------------------------------===//
// AllToAll
//===----------------------------------------------------------------------===//

struct AllToAllParticipant {
size_t rank;

std::vector<se::DeviceMemoryBase> src;
std::vector<se::DeviceMemoryBase> dest;
};

static absl::Status AllToAllOp(
size_t num_bytes, absl::Span<const AllToAllParticipant*> participants) {
absl::c_sort(participants, ByRank<AllToAllParticipant>);

size_t num_participants = participants.size();

for (size_t i = 0; i < num_participants; ++i) {
for (size_t j = 0; j < num_participants; ++j) {
std::memcpy(participants[j]->dest[i].opaque(),
participants[i]->src[j].opaque(), num_bytes);
}
}

return absl::OkStatus();
}

//===----------------------------------------------------------------------===//
// CollectivePermute
//===----------------------------------------------------------------------===//

struct CollectivePermuteParticipant {
size_t rank;
std::optional<RankId> src_rank;

se::DeviceMemoryBase src;
se::DeviceMemoryBase dest;
};

static absl::Status CollectivePermuteOp(
size_t num_bytes,
absl::Span<const CollectivePermuteParticipant*> participants) {
absl::c_sort(participants, ByRank<CollectivePermuteParticipant>);

for (const CollectivePermuteParticipant* participant : participants) {
void* dest = participant->dest.opaque();

if (participant->src_rank) {
size_t src_rank = participant->src_rank->value();
std::memcpy(dest, participants.at(src_rank)->src.opaque(), num_bytes);
} else {
std::memset(dest, 0, num_bytes);
}
}
return absl::OkStatus();
}

//===----------------------------------------------------------------------===//

struct AllReduceParticipantData : ParticipantData {
explicit AllReduceParticipantData(const RendezvousKey& rendezvous_key_p,
int rank)
: ParticipantData(rendezvous_key_p, rank) {}

int64_t element_count;
const void* source_data;
void* destination_data;
PrimitiveType primitive_type;

ReductionKind reduction_kind;

std::string ToString() const override {
return absl::StrFormat(
"AllReduceParticipantData{rank=%d, element_count=%d, type=%s, "
"rendezvous_key=%s}",
local_rank, element_count, PrimitiveType_Name(primitive_type),
rendezvous_key.ToString());
}
};

template <typename T>
T GetInitialValue(ReductionKind reduction_kind) {
switch (reduction_kind) {
Expand Down Expand Up @@ -266,65 +160,136 @@ absl::Status ReduceScatter(ReductionKind reduction_kind,
return absl::OkStatus();
}

class CpuAllReduceRendezvous
: public Rendezvous<AllReduceParticipantData, std::nullptr_t> {
public:
explicit CpuAllReduceRendezvous(const RendezvousKey& k)
: Rendezvous<AllReduceParticipantData, std::nullptr_t>(k) {}
//===----------------------------------------------------------------------===//
// AllReduce
//===----------------------------------------------------------------------===//

protected:
absl::StatusOr<std::nullptr_t> RunCollectiveOp(
const AllReduceParticipantData& me) override {
VLOG(3) << me.ToString();
int64_t world_size = participants_.size();
// Divide the buffer up into equal(ish) chunks. Rank r computes the r-th
// chunk of the output.
int64_t chunk_elems = CeilOfRatio(me.element_count, world_size);

int64_t start_elem = me.local_rank * chunk_elems;
int64_t end_elem = std::min(start_elem + chunk_elems, me.element_count);
chunk_elems = std::max(int64_t{0}, end_elem - start_elem);
if (chunk_elems == 0) {
return nullptr;
}
struct AllReduceParticipant {
size_t rank;
se::DeviceMemoryBase src;
se::DeviceMemoryBase dest;
};

auto bytes_per_elem = primitive_util::ByteWidth(me.primitive_type);
int64_t chunk_offset = start_elem * bytes_per_elem;
int64_t chunk_bytes = chunk_elems * bytes_per_elem;
void* reduce_output =
reinterpret_cast<char*>(me.destination_data) + chunk_offset;
static absl::Status AllReduceOp(
PrimitiveType primitive_type, size_t count, ReductionKind reduction_kind,
absl::Span<const AllReduceParticipant*> participants) {
absl::c_sort(participants, ByRank<AllReduceParticipant>);

std::vector<const void*> inputs;
inputs.reserve(world_size);
for (const auto& p : participants_) {
inputs.push_back(reinterpret_cast<const char*>(p->source_data) +
chunk_offset);
}
if (!primitive_util::IsArrayType(primitive_type)) {
return Unimplemented(
"Unexpected datatype: %s",
primitive_util::LowercasePrimitiveTypeName(primitive_type));
}

if (primitive_util::IsArrayType(me.primitive_type)) {
TF_RETURN_IF_ERROR(primitive_util::ArrayTypeSwitch<absl::Status>(
[&](const auto constant_type) {
return ReduceScatter<constant_type>(me.reduction_kind, inputs,
reduce_output, chunk_elems);
},
me.primitive_type));
} else {
return absl::UnimplementedError(absl::StrCat(
"Unexpected datatype: ",
primitive_util::LowercasePrimitiveTypeName(me.primitive_type)));
// Reduce all inputs into a single output at rank 0.
std::vector<const void*> inputs(participants.size());
for (auto* participant : participants) {
inputs[participant->rank] = participant->src.opaque();
}
void* output = participants[0]->dest.opaque();

TF_RETURN_IF_ERROR(primitive_util::ArrayTypeSwitch<absl::Status>(
[&](const auto constant_type) {
return ReduceScatter<constant_type>(reduction_kind, inputs, output,
count);
},
primitive_type));

// Copy all-reduced output to all other participants.
for (size_t i = 1; i < participants.size(); ++i) {
std::memcpy(participants[i]->dest.opaque(), participants[0]->dest.opaque(),
count * primitive_util::ByteWidth(primitive_type));
}

return absl::OkStatus();
}

//===----------------------------------------------------------------------===//
// AllGather
//===----------------------------------------------------------------------===//

struct AllGatherParticipant {
size_t rank;
se::DeviceMemoryBase src;
se::DeviceMemoryBase dest;
};

static absl::Status AllGatherOp(
size_t num_bytes, absl::Span<const AllGatherParticipant*> participants) {
absl::c_sort(participants, ByRank<AllGatherParticipant>);

size_t num_participants = participants.size();

for (size_t i = 0; i < num_participants; ++i) {
for (size_t j = 0; j < num_participants; ++j) {
std::byte* dest = static_cast<std::byte*>(participants[i]->dest.opaque());
size_t offset = j * num_bytes;
std::memcpy(dest + offset, participants[j]->src.opaque(), num_bytes);
}
}

// All-gather the reduced chunks.
for (const auto& p : participants_) {
if (p->local_rank != me.local_rank) {
std::memcpy(reinterpret_cast<char*>(p->destination_data) + chunk_offset,
reduce_output, chunk_bytes);
}
return absl::OkStatus();
}

//===----------------------------------------------------------------------===//
// AllToAll
//===----------------------------------------------------------------------===//

struct AllToAllParticipant {
size_t rank;

std::vector<se::DeviceMemoryBase> src;
std::vector<se::DeviceMemoryBase> dest;
};

static absl::Status AllToAllOp(
size_t num_bytes, absl::Span<const AllToAllParticipant*> participants) {
absl::c_sort(participants, ByRank<AllToAllParticipant>);

size_t num_participants = participants.size();

for (size_t i = 0; i < num_participants; ++i) {
for (size_t j = 0; j < num_participants; ++j) {
std::memcpy(participants[j]->dest[i].opaque(),
participants[i]->src[j].opaque(), num_bytes);
}
return nullptr;
}

return absl::OkStatus();
}

//===----------------------------------------------------------------------===//
// CollectivePermute
//===----------------------------------------------------------------------===//

struct CollectivePermuteParticipant {
size_t rank;
std::optional<RankId> src_rank;

se::DeviceMemoryBase src;
se::DeviceMemoryBase dest;
};

static absl::Status CollectivePermuteOp(
size_t num_bytes,
absl::Span<const CollectivePermuteParticipant*> participants) {
absl::c_sort(participants, ByRank<CollectivePermuteParticipant>);

for (const CollectivePermuteParticipant* participant : participants) {
void* dest = participant->dest.opaque();

if (participant->src_rank) {
size_t src_rank = participant->src_rank->value();
std::memcpy(dest, participants.at(src_rank)->src.opaque(), num_bytes);
} else {
std::memset(dest, 0, num_bytes);
}
}
return absl::OkStatus();
}

//===----------------------------------------------------------------------===//

struct ReduceScatterParticipantData : ParticipantData {
ReduceScatterParticipantData(const RendezvousKey& rendezvous_key_p, int rank)
: ParticipantData(rendezvous_key_p, rank) {}
Expand Down Expand Up @@ -385,8 +350,6 @@ class CpuReduceScatterRendezvous
} // namespace

struct InProcessCommunicator::State {
RefcountingHashMap<RendezvousKey, CpuAllReduceRendezvous>
all_reduce_rendezvous_map;
RefcountingHashMap<RendezvousKey, CpuReduceScatterRendezvous>
reduce_scatter_rendezvous_map;
};
Expand All @@ -410,24 +373,13 @@ absl::Status InProcessCommunicator::AllReduce(se::DeviceMemoryBase send_buffer,
TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor));
const RendezvousKey& key = cpu_executor->rendezvous_key();

AllReduceParticipantData participant(key, rank_);
participant.element_count = count;
participant.primitive_type = dtype;
participant.source_data = send_buffer.opaque();
participant.destination_data = recv_buffer.opaque();
participant.reduction_kind = reduction_kind;
std::string name = absl::StrCat("all reduce ", key.ToString());
AllReduceParticipant partiticipant{rank_, send_buffer, recv_buffer};

auto make_cpu_rendezvous = [](const RendezvousKey& k) {
return std::make_unique<CpuAllReduceRendezvous>(k);
};

return CpuAllReduceRendezvous::SubmitParticipant(
[&] {
return state_->all_reduce_rendezvous_map.GetOrCreateIfAbsent(
key, make_cpu_rendezvous);
},
participant)
.status();
return RendezvousSingle<absl::Status>(
name, key, partiticipant, key.num_local_participants,
std::bind(AllReduceOp, dtype, count, reduction_kind,
std::placeholders::_1));
}

absl::Status InProcessCommunicator::CollectivePermute(
Expand Down

0 comments on commit 38f3209

Please sign in to comment.