diff --git a/xla/hlo/transforms/BUILD b/xla/hlo/transforms/BUILD index 1fc1ab06ce71f..b306a6dfcbb73 100644 --- a/xla/hlo/transforms/BUILD +++ b/xla/hlo/transforms/BUILD @@ -394,9 +394,12 @@ cc_library( "//xla/hlo/transforms/simplifiers:float_normalization", "//xla/hlo/transforms/simplifiers:hlo_memory_scheduler", "//xla/hlo/transforms/simplifiers:sub_byte_normalization", + "//xla/service:host_memory_offload_annotations_hdr", + "//xla/tsl/platform:errors", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", ], diff --git a/xla/hlo/transforms/despecializer.cc b/xla/hlo/transforms/despecializer.cc index 97b2e5d8bd478..83ef51bc90fcb 100644 --- a/xla/hlo/transforms/despecializer.cc +++ b/xla/hlo/transforms/despecializer.cc @@ -23,12 +23,17 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" #include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/transforms/defuser.h" #include "xla/hlo/transforms/simplifiers/float_normalization.h" #include "xla/hlo/transforms/simplifiers/hlo_memory_scheduler.h" #include "xla/hlo/transforms/simplifiers/sub_byte_normalization.h" +#include "xla/service/host_memory_offload_annotations.h" +#include "xla/tsl/platform/errors.h" #include "xla/xla_data.pb.h" namespace xla { @@ -47,6 +52,10 @@ void Despecializer::AddAssumeGatherIndicesInBoundRewriteToCopy() { pipeline_.AddPass(); } +void Despecializer::AddHostMemoryOffloadRewriteToCopy() { + pipeline_.AddPass(); +} + void Despecializer::AddReduceWindowToReduceBroadcastDeconstruct() { pipeline_.AddPass(); } @@ -57,6 +66,21 @@ absl::StatusOr Despecializer::Run( return pipeline_.Run(module, execution_threads); } +namespace { + +absl::Status RewriteToCopy(const std::vector& candidates) { + for (HloInstruction* gather_indices : candidates) { + auto computation = gather_indices->parent(); + auto copy = computation->AddInstruction( + HloInstruction::CreateUnary(gather_indices->shape(), HloOpcode::kCopy, + gather_indices->mutable_operand(0))); + TF_RETURN_IF_ERROR(computation->ReplaceInstruction(gather_indices, copy)); + } + return absl::OkStatus(); +} + +} // namespace + // AssumeGatherIndicesInBoundRewriteToCopy is needed to handle the // "AssumeGatherIndicesInBound" custom-call in a gather fusion. // "AssumeGatherIndicesInBound" custom-call is a @@ -75,13 +99,29 @@ absl::StatusOr AssumeGatherIndicesInBoundRewriteToCopy::Run( } } } - for (HloInstruction* gather_indices : candidates) { - auto computation = gather_indices->parent(); - auto copy = computation->AddInstruction( - HloInstruction::CreateUnary(gather_indices->shape(), HloOpcode::kCopy, - gather_indices->mutable_operand(0))); - TF_CHECK_OK(computation->ReplaceInstruction(gather_indices, copy)); + TF_RETURN_IF_ERROR(RewriteToCopy(candidates)); + return !candidates.empty(); +} + +absl::StatusOr HostMemoryOffloadRewriteToCopy::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + std::vector candidates; + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->IsCustomCall(host_memory_offload_annotations:: + kPinToDeviceSramCustomCallTarget) || + instruction->IsCustomCall( + host_memory_offload_annotations::kPinToDeviceCustomCallTarget) || + instruction->IsCustomCall( + host_memory_offload_annotations::kMoveToDeviceCustomCallTarget) || + instruction->IsCustomCall( + host_memory_offload_annotations::kMoveToHostCustomCallTarget)) { + candidates.push_back(instruction); + } + } } + TF_RETURN_IF_ERROR(RewriteToCopy(candidates)); return !candidates.empty(); } diff --git a/xla/hlo/transforms/despecializer.h b/xla/hlo/transforms/despecializer.h index 1266eed7af03f..791d4e1992804 100644 --- a/xla/hlo/transforms/despecializer.h +++ b/xla/hlo/transforms/despecializer.h @@ -37,13 +37,20 @@ namespace xla { // optimized for one specific platform on a different platform (undoing platform // specific passes) with matching numerics for comparison. // -// Current despecialization passes are HloDescheduler, ControlDepRemover, -// Defuser and BFloat16MixedPrecisionRemoval. +// Current despecialization passes are +// - HloDescheduler +// - Defuser +// - BFloat16MixedPrecisionRemoval +// - ControlDepRemover +// - DeconstructReduceWindowToReduceBroadcast +// - AssumeGatherIndicesInBoundRewriteToCopy +// - HostMemoryOffloadRewriteToCopy class Despecializer : public HloModulePass { public: Despecializer(); void AddReduceWindowToReduceBroadcastDeconstruct(); void AddAssumeGatherIndicesInBoundRewriteToCopy(); + void AddHostMemoryOffloadRewriteToCopy(); absl::string_view name() const override { return "despecializer"; } using HloPassInterface::Run; absl::StatusOr Run( @@ -66,6 +73,18 @@ class AssumeGatherIndicesInBoundRewriteToCopy : public HloModulePass { const absl::flat_hash_set& execution_threads) override; }; +class HostMemoryOffloadRewriteToCopy : public HloModulePass { + public: + HostMemoryOffloadRewriteToCopy() = default; + absl::string_view name() const override { + return "HostMemoryOffloadRewriteToCopy"; + } + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + class DeconstructReduceWindowToReduceBroadcast : public HloModulePass { public: DeconstructReduceWindowToReduceBroadcast() = default; diff --git a/xla/literal_util.cc b/xla/literal_util.cc index c689d7eb74ad2..fe458d54d1880 100644 --- a/xla/literal_util.cc +++ b/xla/literal_util.cc @@ -717,7 +717,6 @@ absl::StatusOr MakeFakeLiteral( Shape new_shape = shape; new_shape.mutable_layout()->clear_tiles(); new_shape.mutable_layout()->set_tail_padding_alignment_in_elements(1); - new_shape.mutable_layout()->set_element_size_in_bits(0); Literal literal(new_shape); TF_RETURN_IF_ERROR(primitive_util::PrimitiveTypeSwitch(