[XLA:SPMD] support sharding barrier (0/N).
Fold shard group sharding instruction attribute into operand instead of replacing with an explicit copy.

PiperOrigin-RevId: 629575145
Tongfei-Guo authored and copybara-github committed May 1, 2024
1 parent 28c5fed commit 8125cf8
Showing 4 changed files with 118 additions and 65 deletions.
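
To illustrate the behavior described in the commit message, here is a minimal before/after sketch in HLO text. The instruction names, shapes, and shard_as id below are made up for illustration; only the general shape of the transformation comes from this commit. When a "Sharding" custom-call carries a sharding that only annotates a shard group (for example an unknown sharding with shard_as), the pass no longer inserts a copy; it folds the shard-group annotation into the operand's sharding, records the group membership in its side tables, and replaces the custom-call with its operand.

Before ProcessShardingInstruction (hypothetical):
  x = f32[8]{0} add(a, b)
  annot = f32[8]{0} custom-call(x), custom_call_target="Sharding", sharding={unknown shard_as 1}
  ROOT y = f32[8]{0} multiply(annot, annot)

After (conceptually), the shard group of x is tracked in the pass's side tables and the custom-call is gone:
  x = f32[8]{0} add(a, b), sharding={unknown}
  ROOT y = f32[8]{0} multiply(x, x)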
9 changes: 5 additions & 4 deletions xla/hlo/ir/hlo_computation.cc
@@ -1358,14 +1358,15 @@ Status HloComputation::ReplaceWithNewEntryComputationParameter(

absl::StatusOr<bool> HloComputation::ReplaceInstruction(
HloInstruction* old_instruction, HloInstruction* new_instruction,
bool preserve_sharding, bool relay_control_dependency) {
bool preserve_sharding, bool relay_control_dependency,
bool remove_unused_operands) {
TF_RET_CHECK(
ShapeUtil::Compatible(old_instruction->shape(), new_instruction->shape()))
<< ShapeUtil::HumanString(old_instruction->shape()) << " vs "
<< ShapeUtil::HumanString(new_instruction->shape());
return ReplaceInstructionWithDifferentShape(old_instruction, new_instruction,
preserve_sharding,
relay_control_dependency);
return ReplaceInstructionWithDifferentShape(
old_instruction, new_instruction, preserve_sharding,
relay_control_dependency, remove_unused_operands);
}

Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction,
8 changes: 5 additions & 3 deletions xla/hlo/ir/hlo_computation.h
@@ -581,9 +581,11 @@ class HloComputation {
// return false. Otherwise, when the replacement happens, if |new_instruction|
// doesn't have any sharding information it will receive the sharding
// information of |old_instruction|, and function will return true.
absl::StatusOr<bool> ReplaceInstruction(
HloInstruction* old_instruction, HloInstruction* new_instruction,
bool preserve_sharding, bool relay_control_dependency = false);
absl::StatusOr<bool> ReplaceInstruction(HloInstruction* old_instruction,
HloInstruction* new_instruction,
bool preserve_sharding,
bool relay_control_dependency = false,
bool remove_unused_operands = true);

// Same as above, with preserve_sharding=false. Since this replacement always
// happens, it returns just a Status as opposed to StatusOr<bool>
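
The new remove_unused_operands parameter (default true, which keeps the previous behavior) lets a caller replace an instruction without cleaning up operands that become unused. The sharding-propagation change below relies on this: the "Sharding" custom-call is replaced, but its operand must stay alive because it now carries the folded shard-group annotation. A minimal usage sketch follows; the function and variable names are placeholders, not taken from this commit, and the snippet assumes the usual XLA headers plus <tuple> for std::ignore.

// Sketch only: computation, annotation, and replacement are placeholders.
// Keeps annotation's operands alive even if the replacement no longer uses
// them, e.g. because an operand still carries a folded shard_group sharding.
absl::Status ReplaceKeepingOperands(HloComputation* computation,
                                    HloInstruction* annotation,
                                    HloInstruction* replacement) {
  TF_ASSIGN_OR_RETURN(std::ignore,
                      computation->ReplaceInstruction(
                          annotation, replacement,
                          /*preserve_sharding=*/false,
                          /*relay_control_dependency=*/false,
                          /*remove_unused_operands=*/false));
  return absl::OkStatus();
}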
164 changes: 107 additions & 57 deletions xla/service/sharding_propagation.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

@@ -1562,41 +1563,69 @@ absl::StatusOr<bool> ProcessShardingInstruction(
const bool use_shard_group = instruction_to_shard_group_id &&
shard_group_id_to_shard_as_group &&
shard_group_id_to_shard_like_group;
auto process_shard_group_instruction = [&](HloInstruction* instruction,
HloSharding sharding) {
if (use_shard_group && sharding.IsShardGroup()) {
// Store shard group relations.
const int64_t shard_group_id = sharding.GetShardGroup().shard_group_id;
(*instruction_to_shard_group_id)[instruction] = shard_group_id;
if (sharding.IsShardAs()) {
auto& shard_as_group =
(*shard_group_id_to_shard_as_group)[shard_group_id];
if (!shard_as_group.empty()) {
CHECK(ShapeUtil::SameDimensions(instruction->shape(),
(*shard_as_group.begin())->shape()))
<< "Instruction: " << instruction->ToString()
<< " has different shape from the shapes of the other "
"instructions within the same shard_as group: "
<< (*shard_as_group.begin())->shape().ToString();
}
shard_as_group.insert(instruction);
// Process shard group instruction and returns if current instruction needs
// to be removed.
auto process_shard_group_instruction =
[&](HloInstruction* instruction,
bool replaced_with_copy) -> absl::StatusOr<bool> {
if (use_shard_group && instruction->has_sharding() &&
instruction->sharding().IsShardGroup()) {
if (instruction->IsCustomCall("Sharding")) {
CHECK(instruction->operand(0)->opcode() != HloOpcode::kParameter ||
(allow_spmd_sharding_propagation_to_parameters_vector &&
allow_spmd_sharding_propagation_to_parameters_vector->size() ==
module->entry_computation()->num_parameters() &&
allow_spmd_sharding_propagation_to_parameters_vector->at(
instruction->operand(0)->parameter_number())));
}
if (instruction->IsCustomCall("Sharding") && !replaced_with_copy) {
// Pass shard group to operand sharding custom-call if it's not
// replaced with a copy, meaning that the shardings are to annotate
// shard_group or shard_barrier only.
HloSharding operand_sharding = instruction->operand(0)->has_sharding()
? instruction->operand(0)->sharding()
: HloSharding::Unknown();
operand_sharding.SetShardGroup(instruction->sharding().GetShardGroup());
instruction->mutable_operand(0)->set_sharding(operand_sharding);
return true;
} else {
auto& shard_like_group =
(*shard_group_id_to_shard_like_group)[shard_group_id];
if (!shard_like_group.empty()) {
CHECK(ShapeUtil::SameDimensions(instruction->shape(),
(*shard_like_group.begin())->shape()))
<< "Instruction: " << instruction->ToString()
<< " has different shape from the shapes of the other "
"instructions within the same shard_like group: "
<< (*shard_like_group.begin())->shape().ToString();
// Otherwise store the shard group relations.
const int64_t shard_group_id =
instruction->sharding().GetShardGroup().shard_group_id;
(*instruction_to_shard_group_id)[instruction] = shard_group_id;
if (instruction->sharding().IsShardAs()) {
auto& shard_as_group =
(*shard_group_id_to_shard_as_group)[shard_group_id];
if (!shard_as_group.empty()) {
CHECK(ShapeUtil::SameDimensions(instruction->shape(),
(*shard_as_group.begin())->shape()))
<< "Instruction: " << instruction->ToString()
<< " has different shape from the shapes of the other "
"instructions within the same shard_as group: "
<< (*shard_as_group.begin())->shape().ToString();
}
shard_as_group.insert(instruction);
} else {
auto& shard_like_group =
(*shard_group_id_to_shard_like_group)[shard_group_id];
if (!shard_like_group.empty()) {
CHECK(ShapeUtil::SameDimensions(
instruction->shape(), (*shard_like_group.begin())->shape()))
<< "Instruction: " << instruction->ToString()
<< " has different shape from the shapes of the other "
"instructions within the same shard_like group: "
<< (*shard_like_group.begin())->shape().ToString();
}
shard_like_group.insert(instruction);
}
shard_like_group.insert(instruction);
HloSharding sharding = instruction->sharding();
sharding.ClearShardGroup();
instruction->set_sharding(std::move(sharding));
}
sharding.ClearShardGroup();
}
return sharding;
return false;
};

for (HloComputation* computation : module->computations(execution_threads)) {
auto instructions = computation->MakeInstructionPostOrder();
for (auto it = instructions.rbegin(); it != instructions.rend(); ++it) {
@@ -1612,44 +1641,48 @@
Cast<HloCustomCallInstruction>(instruction)->opaque(),
&unspec_dims));

// Replace it with a copy node so that it does not need special
// handling.
if (replace_sharding_with_copy) {
bool replaced_with_copy =
replace_sharding_with_copy &&
(!original_sharding.IsUnknown() ||
instruction->operand(0)->opcode() == HloOpcode::kParameter);
// Replace the sharding instruction with a copy node so that it does not
// need special handling.
if (replaced_with_copy) {
auto copy = computation->AddInstruction(HloInstruction::CreateUnary(
instruction->shape(), HloOpcode::kCopy,
instruction->mutable_operand(0)));
TF_RETURN_IF_ERROR(
computation->ReplaceInstruction(instruction, copy));
// Add into shard group.
HloSharding sharding =
process_shard_group_instruction(copy, original_sharding);
copy->set_sharding(sharding);
TF_ASSIGN_OR_RETURN(
std::ignore, computation->ReplaceInstruction(
instruction, copy, /*preserve_sharding=*/false,
/*relay_control_dependency=*/false,
/*remove_unused_operands=*/false));
copy->set_sharding(original_sharding);
instruction = copy;
changed = true;
}
// Strip the sharding of the shard group related annotations.

TF_ASSIGN_OR_RETURN(
bool shard_group_remove_instruction,
process_shard_group_instruction(instruction, replaced_with_copy));
if (!unspec_dims.empty()) {
absl::c_sort(unspec_dims);
unspecified_dims->emplace(instruction, std::move(unspec_dims));
} else if (!instruction->operand(0)->has_sharding()) {
HloSharding sharding = original_sharding;
if (instruction->operand(0)->opcode() != HloOpcode::kParameter ||
(allow_spmd_sharding_propagation_to_parameters_vector &&
allow_spmd_sharding_propagation_to_parameters_vector->size() ==
module->entry_computation()->num_parameters() &&
allow_spmd_sharding_propagation_to_parameters_vector->at(
instruction->operand(0)->parameter_number()))) {
// Add operand(i.e. the annotated op) into shard group.
sharding = process_shard_group_instruction(
instruction->mutable_operand(0), sharding);
}
instruction->mutable_operand(0)->set_sharding(std::move(sharding));
instruction->mutable_operand(0)->set_sharding(
instruction->sharding());
}
} else if (instruction->has_sharding()) {
// Handle shard group in parameters/outputs.
HloSharding sharding = process_shard_group_instruction(
instruction, instruction->sharding());
instruction->set_sharding(std::move(sharding));
if (shard_group_remove_instruction) {
TF_ASSIGN_OR_RETURN(std::ignore,
computation->ReplaceInstruction(
instruction, instruction->mutable_operand(0),
/*preserve_sharding=*/false,
/*relay_control_dependency=*/false,
/*remove_unused_operands=*/false));
}
} else {
TF_ASSIGN_OR_RETURN(std::ignore,
process_shard_group_instruction(
instruction, /*replaced_with_copy=*/false));
}
}
}
@@ -2975,6 +3008,23 @@ absl::StatusOr<bool> ShardingPropagation::Run(
&shard_group_id_to_shard_like_group,
&allow_spmd_sharding_propagation_to_parameters_vector_));
any_changed |= changed;

for (const auto& [shard_group_id, shard_as_group] :
shard_group_id_to_shard_as_group) {
VLOG(5) << "Shard-As group " << shard_group_id << " contains:";
for (auto instruction : shard_as_group) {
VLOG(5) << " " << instruction->ToString();
}
}

for (const auto& [shard_group_id, shard_like_group] :
shard_group_id_to_shard_like_group) {
VLOG(5) << "Shard-Like group " << shard_group_id << " contains:";
for (auto instruction : shard_like_group) {
VLOG(5) << " " << instruction->ToString();
}
}

// Check sizes of the given allow_spmd_sharding_propagation vectors
if (allow_spmd_sharding_propagation_to_output_) {
CHECK(!module->entry_computation()->root_instruction()->has_sharding() ||
2 changes: 1 addition & 1 deletion xla/service/sharding_propagation_test.cc
@@ -10776,7 +10776,7 @@ TEST_F(ShardingPropagationTest, PropagateShardAsBetweenInputOutput2) {
HloModule jit_f, entry_computation_layout={(f32[8]{0:T(256)})->(f32[8]{0:T(256)}, f32[8]{0:T(256)})}, allow_spmd_sharding_propagation_to_output={true,true}, num_partitions=4
ENTRY main.9 {
Arg_0.1 = f32[8]{0} parameter(0), sharding={replicated}
Arg_0.1 = f32[8]{0} parameter(0)
custom-call.6 = f32[8]{0} custom-call(Arg_0.1), custom_call_target="Sharding", custom_call_has_side_effect=true, sharding={unknown shard_as 0}, metadata={op_name="jit(f)/jit(main)/shard_alike" source_file="third_party/py/jax/tests/shard_alike_test.py" source_line=206}
custom-call.4 = f32[8]{0} custom-call(Arg_0.1), custom_call_target="Sharding", sharding={devices=[4]<=[4]}, metadata={op_name="jit(f)/jit(main)/sharding_constraint[sharding=GSPMDSharding({devices=[4]<=[4]}) resource_env=ResourceEnv(mesh=Mesh(), ()) unconstrained_dims=set()]" source_file="third_party/py/jax/tests/shard_alike_test.py" source_line=204}
constant.0 = f32[] constant(2)
