Skip to content

Commit

Permalink
Add host offloading rewriter to despecializer
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 713029784
  • Loading branch information
Google-ML-Automation committed Feb 3, 2025
1 parent 6743f54 commit 549fd62
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 9 deletions.
3 changes: 3 additions & 0 deletions xla/hlo/transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -394,9 +394,12 @@ cc_library(
"//xla/hlo/transforms/simplifiers:float_normalization",
"//xla/hlo/transforms/simplifiers:hlo_memory_scheduler",
"//xla/hlo/transforms/simplifiers:sub_byte_normalization",
"//xla/service:host_memory_offload_annotations_hdr",
"//xla/tsl/platform:errors",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
],
Expand Down
52 changes: 46 additions & 6 deletions xla/hlo/transforms/despecializer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,17 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/transforms/defuser.h"
#include "xla/hlo/transforms/simplifiers/float_normalization.h"
#include "xla/hlo/transforms/simplifiers/hlo_memory_scheduler.h"
#include "xla/hlo/transforms/simplifiers/sub_byte_normalization.h"
#include "xla/service/host_memory_offload_annotations.h"
#include "xla/tsl/platform/errors.h"
#include "xla/xla_data.pb.h"

namespace xla {
Expand All @@ -47,6 +52,10 @@ void Despecializer::AddAssumeGatherIndicesInBoundRewriteToCopy() {
pipeline_.AddPass<AssumeGatherIndicesInBoundRewriteToCopy>();
}

// Adds the HostMemoryOffloadRewriteToCopy pass to this despecializer's
// internal pass pipeline (`pipeline_`). The pass is opt-in: it is only part of
// the pipeline if this method is called.
void Despecializer::AddHostMemoryOffloadRewriteToCopy() {
pipeline_.AddPass<HostMemoryOffloadRewriteToCopy>();
}

// Adds the DeconstructReduceWindowToReduceBroadcast pass to this
// despecializer's internal pass pipeline (`pipeline_`).
void Despecializer::AddReduceWindowToReduceBroadcastDeconstruct() {
pipeline_.AddPass<DeconstructReduceWindowToReduceBroadcast>();
}
Expand All @@ -57,6 +66,21 @@ absl::StatusOr<bool> Despecializer::Run(
return pipeline_.Run(module, execution_threads);
}

namespace {

// Replaces each instruction in `candidates` with a kCopy of its first operand.
//
// For every candidate, a unary kCopy instruction carrying the candidate's own
// shape is added to the candidate's parent computation, and the candidate is
// then replaced by that copy. Processing stops at the first replacement that
// fails, and that error is returned; otherwise returns OK.
absl::Status RewriteToCopy(const std::vector<HloInstruction*>& candidates) {
  for (HloInstruction* const candidate : candidates) {
    HloComputation* const parent = candidate->parent();
    // Only operand 0 is forwarded into the copy.
    HloInstruction* const source = candidate->mutable_operand(0);
    HloInstruction* const replacement =
        parent->AddInstruction(HloInstruction::CreateUnary(
            candidate->shape(), HloOpcode::kCopy, source));
    TF_RETURN_IF_ERROR(parent->ReplaceInstruction(candidate, replacement));
  }
  return absl::OkStatus();
}

}  // namespace

// AssumeGatherIndicesInBoundRewriteToCopy is needed to handle the
// "AssumeGatherIndicesInBound" custom-call in a gather fusion.
// "AssumeGatherIndicesInBound" custom-call is a
Expand All @@ -75,13 +99,29 @@ absl::StatusOr<bool> AssumeGatherIndicesInBoundRewriteToCopy::Run(
}
}
}
for (HloInstruction* gather_indices : candidates) {
auto computation = gather_indices->parent();
auto copy = computation->AddInstruction(
HloInstruction::CreateUnary(gather_indices->shape(), HloOpcode::kCopy,
gather_indices->mutable_operand(0)));
TF_CHECK_OK(computation->ReplaceInstruction(gather_indices, copy));
TF_RETURN_IF_ERROR(RewriteToCopy(candidates));
return !candidates.empty();
}

// Rewrites host-memory-offload annotation custom-calls into plain copies.
//
// Scans every instruction of every computation in `module` and gathers the
// custom-calls whose target is one of the pin-to-device-SRAM, pin-to-device,
// move-to-device or move-to-host annotation targets. The matches are then
// handed to RewriteToCopy. `execution_threads` is not consulted.
//
// Returns true if at least one instruction matched, false if none did, or the
// error produced while rewriting.
absl::StatusOr<bool> HostMemoryOffloadRewriteToCopy::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  // True when `instr` is a custom-call to any of the annotation targets
  // handled by this pass.
  const auto is_offload_annotation = [](HloInstruction* instr) -> bool {
    namespace annotations = host_memory_offload_annotations;
    if (instr->IsCustomCall(annotations::kPinToDeviceSramCustomCallTarget)) {
      return true;
    }
    if (instr->IsCustomCall(annotations::kPinToDeviceCustomCallTarget)) {
      return true;
    }
    if (instr->IsCustomCall(annotations::kMoveToDeviceCustomCallTarget)) {
      return true;
    }
    return instr->IsCustomCall(annotations::kMoveToHostCustomCallTarget);
  };

  // Collect first, rewrite afterwards, so the instruction lists are not
  // modified while they are being iterated.
  std::vector<HloInstruction*> matches;
  for (HloComputation* comp : module->computations()) {
    for (HloInstruction* instr : comp->instructions()) {
      if (!is_offload_annotation(instr)) {
        continue;
      }
      matches.push_back(instr);
    }
  }

  TF_RETURN_IF_ERROR(RewriteToCopy(matches));
  const bool changed = !matches.empty();
  return changed;
}

Expand Down
23 changes: 21 additions & 2 deletions xla/hlo/transforms/despecializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,20 @@ namespace xla {
// optimized for one specific platform on a different platform (undoing platform
// specific passes) with matching numerics for comparison.
//
// Current despecialization passes are HloDescheduler, ControlDepRemover,
// Defuser and BFloat16MixedPrecisionRemoval.
// Current despecialization passes are
// - HloDescheduler
// - Defuser
// - BFloat16MixedPrecisionRemoval
// - ControlDepRemover
// - DeconstructReduceWindowToReduceBroadcast
// - AssumeGatherIndicesInBoundRewriteToCopy
// - HostMemoryOffloadRewriteToCopy
class Despecializer : public HloModulePass {
public:
Despecializer();
void AddReduceWindowToReduceBroadcastDeconstruct();
void AddAssumeGatherIndicesInBoundRewriteToCopy();
void AddHostMemoryOffloadRewriteToCopy();
absl::string_view name() const override { return "despecializer"; }
using HloPassInterface::Run;
absl::StatusOr<bool> Run(
Expand All @@ -66,6 +73,18 @@ class AssumeGatherIndicesInBoundRewriteToCopy : public HloModulePass {
const absl::flat_hash_set<absl::string_view>& execution_threads) override;
};

// HostMemoryOffloadRewriteToCopy replaces host-memory-offload annotation
// custom-calls (the kPinToDeviceSramCustomCallTarget,
// kPinToDeviceCustomCallTarget, kMoveToDeviceCustomCallTarget and
// kMoveToHostCustomCallTarget targets from host_memory_offload_annotations)
// with a kCopy of the custom-call's first operand.
class HostMemoryOffloadRewriteToCopy : public HloModulePass {
public:
HostMemoryOffloadRewriteToCopy() = default;
// Name under which this pass identifies itself.
absl::string_view name() const override {
return "HostMemoryOffloadRewriteToCopy";
}
// NOTE(review): presumably keeps the other base-class Run overloads visible
// alongside the override declared below.
using HloPassInterface::Run;
// Runs the rewrite on `module`. Returns true if any annotation custom-call
// was replaced, false if none was found, or an error status if a replacement
// failed.
absl::StatusOr<bool> Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) override;
};

class DeconstructReduceWindowToReduceBroadcast : public HloModulePass {
public:
DeconstructReduceWindowToReduceBroadcast() = default;
Expand Down
1 change: 0 additions & 1 deletion xla/literal_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -717,7 +717,6 @@ absl::StatusOr<Literal> MakeFakeLiteral(
Shape new_shape = shape;
new_shape.mutable_layout()->clear_tiles();
new_shape.mutable_layout()->set_tail_padding_alignment_in_elements(1);
new_shape.mutable_layout()->set_element_size_in_bits(0);
Literal literal(new_shape);

TF_RETURN_IF_ERROR(primitive_util::PrimitiveTypeSwitch<absl::Status>(
Expand Down

0 comments on commit 549fd62

Please sign in to comment.