From 2f6eabb5a1d0a4ce5ba9eb0d52620463b3ece2c3 Mon Sep 17 00:00:00 2001
From: Zixuan Jiang
Date: Thu, 9 Jan 2025 08:27:19 -0800
Subject: [PATCH] Rewrite `Reshard(HloSharding::Replicate())` as `Replicate()`
 for `PartitionedHlo`.

PiperOrigin-RevId: 713681703
---
 xla/service/spmd/convolution_handler.cc |  2 +-
 xla/service/spmd/dot_handler.cc         | 14 ++++----
 xla/service/spmd/spmd_partitioner.cc    | 48 +++++++++----------
 3 files changed, 23 insertions(+), 41 deletions(-)

diff --git a/xla/service/spmd/convolution_handler.cc b/xla/service/spmd/convolution_handler.cc
index a084c2ec98fae..aaf27dcd30c19 100644
--- a/xla/service/spmd/convolution_handler.cc
+++ b/xla/service/spmd/convolution_handler.cc
@@ -793,7 +793,7 @@ absl::StatusOr<HloInstruction*> PartitionConvolutionTiledOutput(
   lhs = lhs.Reshard(target_operand_sharding);
 
   // Replicate the RHS.
-  rhs = rhs.Reshard(HloSharding::Replicate());
+  rhs = rhs.Replicate();
 
   // Convolution window config does not include batch and feature dimensions,
   // whereas ReshardAsWindowedInput() expects the same number of window
diff --git a/xla/service/spmd/dot_handler.cc b/xla/service/spmd/dot_handler.cc
index 5a6d1ca7e3351..ef619b7719e7e 100644
--- a/xla/service/spmd/dot_handler.cc
+++ b/xla/service/spmd/dot_handler.cc
@@ -580,13 +580,13 @@ std::optional<WindowedEinsumConfig> GetWindowedEinsumConfiguration(
             ? PartitionedHlo(partitioned_lhs->hlo(),
                              partitioned_lhs->base_shape(),
                              partitioned_lhs->state())
-                  .Reshard(HloSharding::Replicate())
+                  .Replicate()
             : *partitioned_lhs;
     auto new_rhs = rhs_needs_ag
             ? PartitionedHlo(partitioned_rhs->hlo(),
                              partitioned_rhs->base_shape(),
                              partitioned_rhs->state())
-                  .Reshard(HloSharding::Replicate())
+                  .Replicate()
             : *partitioned_rhs;
     dot = (*create_sharded_dot)(new_lhs.hlo(), new_rhs.hlo(), b, conv_window)
               .value();
@@ -2017,16 +2017,14 @@ absl::StatusOr<HloInstruction*> PartitionBaseCase(
   if (lhs_non_contracting_partitions == num_partitions &&
       output_lhs_non_contracting_partitions == num_partitions &&
       lhs_sharding_transposed_to_match_output == output_sharding) {
-    auto rhs_replicated = rhs.Reshard(HloSharding::Replicate()).hlo();
-    return create_sharded_dot(lhs.hlo(), rhs_replicated, b, conv_window);
+    return create_sharded_dot(lhs.hlo(), rhs.Replicate().hlo(), b, conv_window);
   }
 
   // RHS and output have the same partitioned non-contracting dimensions.
   if (rhs_non_contracting_partitions == num_partitions &&
       output_rhs_non_contracting_partitions == num_partitions &&
       rhs_sharding_transposed_to_match_output == output_sharding) {
-    auto lhs_replicated = lhs.Reshard(HloSharding::Replicate()).hlo();
-    return create_sharded_dot(lhs_replicated, rhs.hlo(), b, conv_window);
+    return create_sharded_dot(lhs.Replicate().hlo(), rhs.hlo(), b, conv_window);
   }
 
   if (may_reshard_without_detecting_match) {
@@ -2043,13 +2041,13 @@ absl::StatusOr<HloInstruction*> PartitionBaseCase(
     if (output_lhs_non_contracting_partitions == num_partitions) {
       auto resharded_lhs =
           lhs.Reshard(*output_sharding_transposed_to_match_lhs);
-      auto replicated_rhs = rhs.Reshard(HloSharding::Replicate());
+      auto replicated_rhs = rhs.Replicate();
       return create_sharded_dot(resharded_lhs.hlo(), replicated_rhs.hlo(), b,
                                 conv_window);
     }
     // Output is partitioned along RHS non-contracting dimensions.
     if (output_rhs_non_contracting_partitions == num_partitions) {
-      auto replicated_lhs = lhs.Reshard(HloSharding::Replicate());
+      auto replicated_lhs = lhs.Replicate();
       auto resharded_rhs =
           rhs.Reshard(*output_sharding_transposed_to_match_rhs);
       return create_sharded_dot(replicated_lhs.hlo(), resharded_rhs.hlo(), b,
diff --git a/xla/service/spmd/spmd_partitioner.cc b/xla/service/spmd/spmd_partitioner.cc
index b4f09c7dbbc31..3072cefc28a4e 100644
--- a/xla/service/spmd/spmd_partitioner.cc
+++ b/xla/service/spmd/spmd_partitioner.cc
@@ -432,7 +432,7 @@ PartitionedHlo PartitionedHlo::Reshard(const HloSharding& target,
   // propagated to constant.)
   if (hlo()->opcode() == HloOpcode::kConstant && !sharding().IsManual() &&
       target.IsManual()) {
-    PartitionedHlo pconstant = this->Reshard(HloSharding::Replicate());
+    PartitionedHlo pconstant = this->Replicate();
     pconstant.hlo()->set_sharding(target);
     return pconstant;
   }
@@ -2913,8 +2913,7 @@ absl::Status SpmdPartitioningVisitor::HandleSort(HloInstruction* hlo) {
         slice_input, ShapeUtil::MakeShape(element_type, replicated_dimensions),
         MakePartitioningState());
     // Reshard value to be replicated.
-    auto replicated_slice_input =
-        partitioned_slice_input.Reshard(HloSharding::Replicate()).hlo();
+    auto replicated_slice_input = partitioned_slice_input.Replicate().hlo();
 
     // Slice top K index from the first partitioned sort.
     auto slice_index = SliceFirstK(index_gte, &b_, sort_dim, k.value());
@@ -2923,8 +2922,7 @@ absl::Status SpmdPartitioningVisitor::HandleSort(HloInstruction* hlo) {
         slice_index, ShapeUtil::MakeShape(index_type, replicated_dimensions),
         MakePartitioningState());
     // Reshard value to be replicated.
-    auto replicated_slice_index =
-        partitioned_slice_index.Reshard(HloSharding::Replicate()).hlo();
+    auto replicated_slice_index = partitioned_slice_index.Replicate().hlo();
 
     // Creates replicated sort to do TopK, the input is value and index pairs
     // from all the partitions.
@@ -3566,9 +3564,7 @@ absl::Status SpmdPartitioningVisitor::HandleDynamicSlice(HloInstruction* hlo) {
       continue;
     }
     // Replicate the indices.
-    new_indices[i] = GetPartitionedHlo(hlo->operand(i + 1))
-                         .Reshard(HloSharding::Replicate())
-                         .hlo();
+    new_indices[i] = GetPartitionedHlo(hlo->operand(i + 1)).Replicate().hlo();
   }
   SetPartitionedHlo(hlo, [&]() {
     auto partitioned_shape =
@@ -3623,9 +3619,7 @@ absl::Status SpmdPartitioningVisitor::HandleDynamicUpdateSlice(
     std::vector<HloInstruction*> new_indices(hlo->shape().rank());
     for (int64_t i = 0; i < new_indices.size(); ++i) {
       // Replicate the indices.
-      new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2))
-                           .Reshard(HloSharding::Replicate())
-                           .hlo();
+      new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2)).Replicate().hlo();
     }
     auto dus = b_.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
         base.hlo()->shape(), base.hlo(), operand.hlo(), new_indices));
@@ -3654,9 +3648,7 @@ absl::Status SpmdPartitioningVisitor::HandleDynamicUpdateSlice(
         continue;
       }
       // Replicate the indices.
-      new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2))
-                           .Reshard(HloSharding::Replicate())
-                           .hlo();
+      new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2)).Replicate().hlo();
     }
 
     // Get partitioned input.
@@ -3774,9 +3766,7 @@ absl::Status SpmdPartitioningVisitor::HandleDynamicUpdateSlice(
       continue;
     }
     // Replicate the indices.
-    new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2))
-                         .Reshard(HloSharding::Replicate())
-                         .hlo();
+    new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2)).Replicate().hlo();
   }
   SetPartitionedHlo(hlo, [&]() {
     auto partitioned_shape =
@@ -3944,9 +3934,7 @@ absl::Status SpmdPartitioningVisitor::HandlePad(HloInstruction* hlo) {
     return DefaultAction(hlo);
   }
   auto lhs = GetPartitionedHlo(hlo->operand(0));
-  auto replicated_rhs = GetPartitionedHlo(hlo->operand(1))
-                            .Reshard(HloSharding::Replicate())
-                            .hlo();
+  auto replicated_rhs = GetPartitionedHlo(hlo->operand(1)).Replicate().hlo();
   auto reshard_operand = ReshardDataForPad(
       replicated_rhs, hlo->padding_config(), lhs, hlo->sharding(), &b_);
   if (!reshard_operand.has_value()) {
@@ -4025,7 +4013,7 @@ absl::Status SpmdPartitioningVisitor::HandleReduce(HloInstruction* hlo) {
 
   for (int64_t operand_id = 0; operand_id < input_count; ++operand_id) {
     inits.push_back(GetPartitionedHlo(hlo->operand(operand_id + input_count))
-                        .Reshard(HloSharding::Replicate())
+                        .Replicate()
                         .hlo());
     inputs.push_back(GetPartitionedHlo(hlo->operand(operand_id)));
     if (operand_id > 0) {
@@ -4210,9 +4198,7 @@ absl::Status SpmdPartitioningVisitor::HandleConditional(HloInstruction* hlo) {
                  .Reshard(hlo_sharding_util::UngroupSharding(grouped_sharding))
                  .hlo();
     } else {
-      cond = GetPartitionedHlo(hlo->operand(0))
-                 .Reshard(HloSharding::Replicate())
-                 .hlo();
+      cond = GetPartitionedHlo(hlo->operand(0)).Replicate().hlo();
     }
   }
   return b_.AddInstruction(HloInstruction::CreateConditional(
@@ -4438,7 +4424,7 @@ absl::Status SpmdPartitioningVisitor::HandleRng(HloInstruction* hlo) {
       // Run on a single device (0) and distribute the data to all other cores.
      auto clone = clone_from_original(HloSharding::AssignDevice(0));
      return PartitionedHlo(clone, hlo->shape(), MakePartitioningState())
-          .Reshard(HloSharding::Replicate())
+          .Replicate()
          .hlo();
    });
    return absl::OkStatus();
@@ -4449,9 +4435,8 @@ absl::Status SpmdPartitioningVisitor::HandleRng(HloInstruction* hlo) {
  std::vector<HloInstruction*> new_operands;
  new_operands.reserve(hlo->operand_count());
  for (int64_t i = 0; i < hlo->operand_count(); ++i) {
-    new_operands.push_back(GetPartitionedHlo(hlo->operand(i))
-                               .Reshard(HloSharding::Replicate())
-                               .hlo());
+    new_operands.push_back(
+        GetPartitionedHlo(hlo->operand(i)).Replicate().hlo());
  }
 
  if (!hlo->sharding().ReplicateOnLastTileDim()) {
@@ -4498,8 +4483,8 @@ absl::Status SpmdPartitioningVisitor::HandleReduceWindow(HloInstruction* hlo) {
  for (const HloInstruction* input_array : input_arrays) {
    PartitionedHlo& operand = GetPartitionedHlo(input_array);
    // Replicate init
-    PartitionedHlo replicated_init = GetPartitionedHlo(init_values[input_idx])
-                                         .Reshard(HloSharding::Replicate());
+    PartitionedHlo replicated_init =
+        GetPartitionedHlo(init_values[input_idx]).Replicate();
    const HloSharding& sharding =
        hlo->sharding().IsTuple()
            ? hlo->sharding().tuple_elements()[input_idx]
@@ -4601,8 +4586,7 @@ absl::Status SpmdPartitioningVisitor::HandleSelectAndScatter(
            : LiteralUtil::CreateR0<float>(float_pad_value)));
 
  // Replicate init
-  auto replicated_init = GetPartitionedHlo(hlo->mutable_operand(2))
-                             .Reshard(HloSharding::Replicate());
+  auto replicated_init = GetPartitionedHlo(hlo->mutable_operand(2)).Replicate();
 
  auto state = MakePartitioningState();
  auto partition_ordinals =
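
Editor's note: the patch above is a pure call-site simplification; at every touched
site, PartitionedHlo::Replicate() stands in for Reshard(HloSharding::Replicate()).
The sketch below illustrates that equivalence with a deliberately trimmed-down,
hypothetical stand-in for PartitionedHlo. The member names here are invented for
illustration; the real class (declared in xla/service/spmd/spmd_partitioner.h)
also carries a base shape, partitioning state, and a reshard cache, all omitted.

// Minimal, self-contained sketch of the refactoring; NOT the real XLA class.
#include <iostream>
#include <string>

struct HloSharding {
  static HloSharding Replicate() { return HloSharding{"{replicated}"}; }
  std::string repr;  // hypothetical stand-in for the real sharding payload
};

struct PartitionedHlo {
  // General resharding entry point, mirroring the real Reshard(target).
  PartitionedHlo Reshard(const HloSharding& target) const {
    return PartitionedHlo{target.repr};
  }
  // Convenience wrapper: the form the patch rewrites call sites to use.
  PartitionedHlo Replicate() const {
    return Reshard(HloSharding::Replicate());
  }
  std::string sharding_repr;
};

int main() {
  PartitionedHlo rhs{"{devices=[4]0,1,2,3}"};
  PartitionedHlo before = rhs.Reshard(HloSharding::Replicate());  // old form
  PartitionedHlo after = rhs.Replicate();                         // new form
  std::cout << (before.sharding_repr == after.sharding_repr) << "\n";  // 1
}

Beyond brevity, routing every replication through one named helper gives the
implementation a single place to hang replication-specific handling; whether
the real Replicate() adds such fast paths is not visible in this patch.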