Rewrite Reshard(HloSharding::Replicate()) as Replicate() for `PartitionedHlo`.

PiperOrigin-RevId: 713681703
ZixuanJiang authored and Google-ML-Automation committed Jan 9, 2025
1 parent ddb08a6 commit 2f6eabb
Showing 3 changed files with 23 additions and 41 deletions.
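This commit is a mechanical cleanup: every call site that replicated a `PartitionedHlo` by resharding to `HloSharding::Replicate()` now calls the dedicated `Replicate()` method instead. A minimal sketch of the presumed equivalence follows; the actual definition lives in the `PartitionedHlo` class of the SPMD partitioner and may differ (for example, by reusing a cached replicated copy):

// Sketch only, not the actual XLA source: Replicate() is assumed to be a
// thin convenience wrapper over Reshard() with a fully replicated target
// sharding, so the two spellings below produce the same HLO.
PartitionedHlo PartitionedHlo::Replicate() {
  return Reshard(HloSharding::Replicate());
}

// Call sites shrink from
//   auto replicated = phlo.Reshard(HloSharding::Replicate()).hlo();
// to
//   auto replicated = phlo.Replicate().hlo();

Beyond brevity, a named method presumably gives the implementation a single place to special-case full replication, rather than detecting that target sharding at every call site.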
2 changes: 1 addition & 1 deletion xla/service/spmd/convolution_handler.cc
@@ -793,7 +793,7 @@ absl::StatusOr<HloInstruction*> PartitionConvolutionTiledOutput(
   lhs = lhs.Reshard(target_operand_sharding);

   // Replicate the RHS.
-  rhs = rhs.Reshard(HloSharding::Replicate());
+  rhs = rhs.Replicate();

   // Convolution window config does not include batch and feature dimensions,
   // whereas ReshardAsWindowedInput() expects the same number of window
14 changes: 6 additions & 8 deletions xla/service/spmd/dot_handler.cc
@@ -580,13 +580,13 @@ std::optional<WindowedEinsumConfig> GetWindowedEinsumConfiguration(
                      ? PartitionedHlo(partitioned_lhs->hlo(),
                                       partitioned_lhs->base_shape(),
                                       partitioned_lhs->state())
-                           .Reshard(HloSharding::Replicate())
+                           .Replicate()
                      : *partitioned_lhs;
   auto new_rhs = rhs_needs_ag
                      ? PartitionedHlo(partitioned_rhs->hlo(),
                                       partitioned_rhs->base_shape(),
                                       partitioned_rhs->state())
-                           .Reshard(HloSharding::Replicate())
+                           .Replicate()
                      : *partitioned_rhs;
   dot = (*create_sharded_dot)(new_lhs.hlo(), new_rhs.hlo(), b, conv_window)
             .value();
@@ -2017,16 +2017,14 @@ absl::StatusOr<HloInstruction*> PartitionBaseCase(
   if (lhs_non_contracting_partitions == num_partitions &&
       output_lhs_non_contracting_partitions == num_partitions &&
       lhs_sharding_transposed_to_match_output == output_sharding) {
-    auto rhs_replicated = rhs.Reshard(HloSharding::Replicate()).hlo();
-    return create_sharded_dot(lhs.hlo(), rhs_replicated, b, conv_window);
+    return create_sharded_dot(lhs.hlo(), rhs.Replicate().hlo(), b, conv_window);
   }

   // RHS and output have the same partitioned non-contracting dimensions.
   if (rhs_non_contracting_partitions == num_partitions &&
       output_rhs_non_contracting_partitions == num_partitions &&
       rhs_sharding_transposed_to_match_output == output_sharding) {
-    auto lhs_replicated = lhs.Reshard(HloSharding::Replicate()).hlo();
-    return create_sharded_dot(lhs_replicated, rhs.hlo(), b, conv_window);
+    return create_sharded_dot(lhs.Replicate().hlo(), rhs.hlo(), b, conv_window);
   }

   if (may_reshard_without_detecting_match) {
@@ -2043,13 +2041,13 @@ absl::StatusOr<HloInstruction*> PartitionBaseCase(
     if (output_lhs_non_contracting_partitions == num_partitions) {
       auto resharded_lhs =
           lhs.Reshard(*output_sharding_transposed_to_match_lhs);
-      auto replicated_rhs = rhs.Reshard(HloSharding::Replicate());
+      auto replicated_rhs = rhs.Replicate();
       return create_sharded_dot(resharded_lhs.hlo(), replicated_rhs.hlo(), b,
                                 conv_window);
     }
     // Output is partitioned along RHS non-contracting dimensions.
     if (output_rhs_non_contracting_partitions == num_partitions) {
-      auto replicated_lhs = lhs.Reshard(HloSharding::Replicate());
+      auto replicated_lhs = lhs.Replicate();
       auto resharded_rhs =
           rhs.Reshard(*output_sharding_transposed_to_match_rhs);
       return create_sharded_dot(replicated_lhs.hlo(), resharded_rhs.hlo(), b,
48 changes: 16 additions & 32 deletions xla/service/spmd/spmd_partitioner.cc
@@ -432,7 +432,7 @@ PartitionedHlo PartitionedHlo::Reshard(const HloSharding& target,
   // propagated to constant.)
   if (hlo()->opcode() == HloOpcode::kConstant && !sharding().IsManual() &&
       target.IsManual()) {
-    PartitionedHlo pconstant = this->Reshard(HloSharding::Replicate());
+    PartitionedHlo pconstant = this->Replicate();
     pconstant.hlo()->set_sharding(target);
     return pconstant;
   }
@@ -2913,8 +2913,7 @@ absl::Status SpmdPartitioningVisitor::HandleSort(HloInstruction* hlo) {
       slice_input, ShapeUtil::MakeShape(element_type, replicated_dimensions),
       MakePartitioningState());
   // Reshard value to be replicated.
-  auto replicated_slice_input =
-      partitioned_slice_input.Reshard(HloSharding::Replicate()).hlo();
+  auto replicated_slice_input = partitioned_slice_input.Replicate().hlo();

   // Slice top K index from the first partitioned sort.
   auto slice_index = SliceFirstK(index_gte, &b_, sort_dim, k.value());
@@ -2923,8 +2922,7 @@ absl::Status SpmdPartitioningVisitor::HandleSort(HloInstruction* hlo) {
       slice_index, ShapeUtil::MakeShape(index_type, replicated_dimensions),
       MakePartitioningState());
   // Reshard value to be replicated.
-  auto replicated_slice_index =
-      partitioned_slice_index.Reshard(HloSharding::Replicate()).hlo();
+  auto replicated_slice_index = partitioned_slice_index.Replicate().hlo();

   // Creates replicated sort to do TopK; the input is value and index pairs
   // from all the partitions.
@@ -3566,9 +3564,7 @@ absl::Status SpmdPartitioningVisitor::HandleDynamicSlice(HloInstruction* hlo) {
       continue;
     }
     // Replicate the indices.
-    new_indices[i] = GetPartitionedHlo(hlo->operand(i + 1))
-                         .Reshard(HloSharding::Replicate())
-                         .hlo();
+    new_indices[i] = GetPartitionedHlo(hlo->operand(i + 1)).Replicate().hlo();
   }
   SetPartitionedHlo(hlo, [&]() {
     auto partitioned_shape =
@@ -3623,9 +3619,7 @@ absl::Status SpmdPartitioningVisitor::HandleDynamicUpdateSlice(
   std::vector<HloInstruction*> new_indices(hlo->shape().rank());
   for (int64_t i = 0; i < new_indices.size(); ++i) {
     // Replicate the indices.
-    new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2))
-                         .Reshard(HloSharding::Replicate())
-                         .hlo();
+    new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2)).Replicate().hlo();
   }
   auto dus = b_.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
       base.hlo()->shape(), base.hlo(), operand.hlo(), new_indices));
@@ -3654,9 +3648,7 @@ absl::Status SpmdPartitioningVisitor::HandleDynamicUpdateSlice(
       continue;
     }
     // Replicate the indices.
-    new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2))
-                         .Reshard(HloSharding::Replicate())
-                         .hlo();
+    new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2)).Replicate().hlo();
   }

   // Get partitioned input.
@@ -3774,9 +3766,7 @@ absl::Status SpmdPartitioningVisitor::HandleDynamicUpdateSlice(
       continue;
     }
     // Replicate the indices.
-    new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2))
-                         .Reshard(HloSharding::Replicate())
-                         .hlo();
+    new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2)).Replicate().hlo();
   }
   SetPartitionedHlo(hlo, [&]() {
     auto partitioned_shape =
@@ -3944,9 +3934,7 @@ absl::Status SpmdPartitioningVisitor::HandlePad(HloInstruction* hlo) {
     return DefaultAction(hlo);
   }
   auto lhs = GetPartitionedHlo(hlo->operand(0));
-  auto replicated_rhs = GetPartitionedHlo(hlo->operand(1))
-                            .Reshard(HloSharding::Replicate())
-                            .hlo();
+  auto replicated_rhs = GetPartitionedHlo(hlo->operand(1)).Replicate().hlo();
   auto reshard_operand = ReshardDataForPad(
       replicated_rhs, hlo->padding_config(), lhs, hlo->sharding(), &b_);
   if (!reshard_operand.has_value()) {
@@ -4025,7 +4013,7 @@ absl::Status SpmdPartitioningVisitor::HandleReduce(HloInstruction* hlo) {

   for (int64_t operand_id = 0; operand_id < input_count; ++operand_id) {
     inits.push_back(GetPartitionedHlo(hlo->operand(operand_id + input_count))
-                        .Reshard(HloSharding::Replicate())
+                        .Replicate()
                         .hlo());
     inputs.push_back(GetPartitionedHlo(hlo->operand(operand_id)));
     if (operand_id > 0) {
@@ -4210,9 +4198,7 @@ absl::Status SpmdPartitioningVisitor::HandleConditional(HloInstruction* hlo) {
                  .Reshard(hlo_sharding_util::UngroupSharding(grouped_sharding))
                  .hlo();
     } else {
-      cond = GetPartitionedHlo(hlo->operand(0))
-                 .Reshard(HloSharding::Replicate())
-                 .hlo();
+      cond = GetPartitionedHlo(hlo->operand(0)).Replicate().hlo();
     }
   }
   return b_.AddInstruction(HloInstruction::CreateConditional(
return b_.AddInstruction(HloInstruction::CreateConditional(
@@ -4438,7 +4424,7 @@ absl::Status SpmdPartitioningVisitor::HandleRng(HloInstruction* hlo) {
     // Run on a single device (0) and distribute the data to all other cores.
     auto clone = clone_from_original(HloSharding::AssignDevice(0));
     return PartitionedHlo(clone, hlo->shape(), MakePartitioningState())
-        .Reshard(HloSharding::Replicate())
+        .Replicate()
         .hlo();
   });
   return absl::OkStatus();
@@ -4449,9 +4435,8 @@ absl::Status SpmdPartitioningVisitor::HandleRng(HloInstruction* hlo) {
   std::vector<HloInstruction*> new_operands;
   new_operands.reserve(hlo->operand_count());
   for (int64_t i = 0; i < hlo->operand_count(); ++i) {
-    new_operands.push_back(GetPartitionedHlo(hlo->operand(i))
-                               .Reshard(HloSharding::Replicate())
-                               .hlo());
+    new_operands.push_back(
+        GetPartitionedHlo(hlo->operand(i)).Replicate().hlo());
   }

   if (!hlo->sharding().ReplicateOnLastTileDim()) {
@@ -4498,8 +4483,8 @@ absl::Status SpmdPartitioningVisitor::HandleReduceWindow(HloInstruction* hlo) {
   for (const HloInstruction* input_array : input_arrays) {
     PartitionedHlo& operand = GetPartitionedHlo(input_array);
     // Replicate init
-    PartitionedHlo replicated_init = GetPartitionedHlo(init_values[input_idx])
-                                         .Reshard(HloSharding::Replicate());
+    PartitionedHlo replicated_init =
+        GetPartitionedHlo(init_values[input_idx]).Replicate();

     const HloSharding& sharding =
         hlo->sharding().IsTuple() ? hlo->sharding().tuple_elements()[input_idx]
@@ -4601,8 +4586,7 @@ absl::Status SpmdPartitioningVisitor::HandleSelectAndScatter(
           : LiteralUtil::CreateR0<float>(float_pad_value)));

   // Replicate init
-  auto replicated_init = GetPartitionedHlo(hlo->mutable_operand(2))
-                             .Reshard(HloSharding::Replicate());
+  auto replicated_init = GetPartitionedHlo(hlo->mutable_operand(2)).Replicate();

   auto state = MakePartitioningState();
   auto partition_ordinals =
