Skip to content

Commit

Permalink
Add fake OpaqueExecutable and update HLO runner interface use it.
Browse files Browse the repository at this point in the history
We want to migrate all uses of `xla::Executable` that interact with the HLO
runners to `xla::OpaqueExecutable`. This will be a new class that is not a
member of the `xla::Executable` class hierarchy. The plan is for this class to
have no public fields or accessors and for it to solely be used for wrapping
runner-specific executables within.

This is step 1/3.

PiperOrigin-RevId: 722727338
  • Loading branch information
nvgrw authored and Google-ML-Automation committed Feb 3, 2025
1 parent a42a623 commit 264b591
Show file tree
Hide file tree
Showing 7 changed files with 175 additions and 48 deletions.
20 changes: 18 additions & 2 deletions xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -4645,11 +4645,11 @@ cc_library(
"//xla/tsl/platform:env",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:statusor",
],
)

Expand All @@ -4660,24 +4660,40 @@ cc_library(
deps = [
":backend",
":compiler",
":computation_layout",
":computation_placer",
":executable",
":hlo_module_util",
":hlo_runner_interface",
":maybe_owning_device_memory",
":shaped_buffer",
":transfer_manager",
"//xla:executable_run_options",
"//xla:literal",
"//xla:shape_tree",
"//xla:shape_util",
"//xla:status_macros",
"//xla:types",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/ir:hlo_module_group",
"//xla/service/gpu:gpu_executable_run_options",
"//xla/stream_executor:device_description",
"//xla/stream_executor:device_memory",
"//xla/stream_executor:device_memory_allocator",
"//xla/stream_executor:platform",
"//xla/stream_executor:stream_executor_h",
"//xla/stream_executor:stream_executor_memory_allocator",
"//xla/tsl/platform:env",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:logging",
"//xla/tsl/platform:status",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
"@eigen_archive//:eigen3",
],
Expand Down
78 changes: 67 additions & 11 deletions xla/service/hlo_runner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,52 @@ limitations under the License.

#include "xla/service/hlo_runner.h"

#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <variant>
#include <vector>

#include "absl/base/nullability.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "unsupported/Eigen/CXX11/Tensor"
#include "xla/executable_run_options.h"
#include "xla/hlo/ir/hlo_input_output_alias_config.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_module_group.h"
#include "xla/literal.h"
#include "xla/service/backend.h"
#include "xla/service/computation_placer.h"
#include "xla/service/executable.h"
#include "xla/service/gpu/gpu_executable_run_options.h"
#include "xla/service/hlo_module_util.h"
#include "xla/service/hlo_runner_interface.h"
#include "xla/service/maybe_owning_device_memory.h"
#include "xla/service/service_executable_run_options.h"
#include "xla/service/shaped_buffer.h"
#include "xla/service/transfer_manager.h"
#include "xla/shape.h"
#include "xla/shape_tree.h"
#include "xla/shape_util.h"
#include "xla/status_macros.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/device_memory_allocator.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/stream_executor.h"
#include "xla/stream_executor/stream_executor_memory_allocator.h"
#include "xla/tsl/platform/env.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/logging.h"
#include "xla/tsl/platform/status.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/platform/threadpool.h"

namespace xla {

Expand Down Expand Up @@ -186,7 +214,7 @@ absl::StatusOr<Literal> HloRunner::ExecuteWithBufferAssignment(
}

absl::StatusOr<Literal> HloRunner::ExecuteWithExecutable(
Executable* executable, absl::Span<const Literal* const> arguments,
OpaqueExecutable* executable, absl::Span<const Literal* const> arguments,
ExecutionProfile* profile) {
entry_computation_layout_ =
&(executable->module().entry_computation_layout());
Expand Down Expand Up @@ -286,14 +314,14 @@ absl::StatusOr<ExecutionOutput> HloRunner::ExecuteWithDeviceBuffers(
std::unique_ptr<HloModule> module,
absl::Span<ScopedShapedBuffer const> arguments, bool run_hlo_passes,
ExecutionProfile* profile) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
TF_ASSIGN_OR_RETURN(std::unique_ptr<OpaqueExecutable> executable,
CreateExecutable(std::move(module), run_hlo_passes));
return ExecuteWithDeviceBuffers(executable.get(), arguments, profile);
}

absl::StatusOr<ExecutionOutput> HloRunner::ExecuteWithDeviceBuffers(
Executable* executable, absl::Span<ScopedShapedBuffer const> arguments,
ExecutionProfile* profile) {
OpaqueExecutable* executable,
absl::Span<ScopedShapedBuffer const> arguments, ExecutionProfile* profile) {
std::vector<ExecutionInput> execution_arguments =
ExecutionInputsFromScopedShapedBuffers(
arguments, executable->module().input_output_alias_config(),
Expand All @@ -319,7 +347,7 @@ HloRunner::ExecuteWithMovedDeviceBuffersAndBufferAssignment(
std::vector<ScopedShapedBuffer> arguments, bool run_hlo_passes,
ExecutionProfile* profile) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Executable> executable,
std::unique_ptr<OpaqueExecutable> executable,
CreateExecutableWithBufferAssignment(
std::move(module), buffer_assignment_proto, run_hlo_passes));
return ExecuteWithMovedDeviceBuffers(executable.get(), std::move(arguments),
Expand Down Expand Up @@ -384,7 +412,7 @@ absl::StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
std::unique_ptr<HloModule> module, const ReplicatedExecuteOptions& options,
DeviceAssignment* device_assignment) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Executable> executable,
std::unique_ptr<OpaqueExecutable> executable,
CreateExecutable(std::move(module), options.run_hlo_passes));
return ExecuteReplicated(executable.get(), options, device_assignment);
}
Expand Down Expand Up @@ -529,7 +557,7 @@ absl::StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicatedImpl(
}

absl::StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
Executable* executable, const ReplicatedExecuteOptions& options,
OpaqueExecutable* executable, const ReplicatedExecuteOptions& options,
DeviceAssignment* device_assignment, ExecutionProfile* profile) {
return ExecuteReplicatedImpl(
[&](const std::vector<ServiceExecutableRunOptions>& service_run_options,
Expand Down Expand Up @@ -577,7 +605,7 @@ absl::StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
}

absl::StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
std::function<Executable*(int64_t)> executable_provider,
std::function<OpaqueExecutable*(int64_t)> executable_provider,
std::function<int64_t(int64_t)> argument_count_provider,
std::function<const Literal*(int64_t, int64_t)> argument_provider,
const ReplicatedExecuteOptions& options,
Expand Down Expand Up @@ -640,14 +668,14 @@ absl::StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
return ExecuteReplicated(std::move(module), options, &device_assignment);
}

absl::StatusOr<std::unique_ptr<Executable>> HloRunner::CreateExecutable(
absl::StatusOr<std::unique_ptr<OpaqueExecutable>> HloRunner::CreateExecutable(
std::unique_ptr<HloModule> module, bool run_hlo_passes) {
return CreateExecutableWithBufferAssignment(
std::move(module),
/*buffer_assignment_proto=*/nullptr, run_hlo_passes);
}

absl::StatusOr<std::unique_ptr<Executable>>
absl::StatusOr<std::unique_ptr<OpaqueExecutable>>
HloRunner::CreateExecutableWithBufferAssignment(
std::unique_ptr<HloModule> module,
const BufferAssignmentProto* buffer_assignment_proto, bool run_hlo_passes) {
Expand All @@ -665,7 +693,7 @@ HloRunner::CreateExecutableWithBufferAssignment(
}
auto module_group = std::make_unique<HloModuleGroup>(std::move(module));
TF_ASSIGN_OR_RETURN(
auto executables,
std::vector<std::unique_ptr<Executable>> executables,
backend().compiler()->Compile(std::move(module_group),
{{backend().default_stream_executor()}},
backend().memory_allocator()));
Expand Down Expand Up @@ -724,4 +752,32 @@ bool HloRunner::HasProperty(const HloRunnerPropertyTag::Type tag) const {
return false;
}

absl::StatusOr<Executable*> HloRunner::ExecutableFromWrapped(
const OpaqueExecutable* wrapped) const {
return const_cast<Executable*>(wrapped);
}

absl::StatusOr<std::unique_ptr<Executable>> HloRunner::ExecutableFromWrapped(
std::unique_ptr<OpaqueExecutable> wrapped) const {
return std::unique_ptr<Executable>(wrapped.release());
}

std::unique_ptr<OpaqueExecutable> HloRunner::WrapExecutable(
std::unique_ptr<Executable> executable) const {
return std::unique_ptr<OpaqueExecutable>(executable.release());
}

absl::StatusOr<absl::Nonnull<const HloModule*>> HloRunner::HloModuleFromWrapped(
const OpaqueExecutable* wrapped) const {
if (wrapped->has_module()) {
return &wrapped->module();
}
return absl::NotFoundError("OpaqueExecutable does not contain an HloModule.");
}

absl::StatusOr<absl::Nonnull<const HloProto*>> HloRunner::HloProtoFromWrapped(
const OpaqueExecutable* wrapped) const {
return wrapped->hlo_proto();
}

} // namespace xla
49 changes: 37 additions & 12 deletions xla/service/hlo_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,29 @@ limitations under the License.
#ifndef XLA_SERVICE_HLO_RUNNER_H_
#define XLA_SERVICE_HLO_RUNNER_H_

#include <map>
#include <cstdint>
#include <functional>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include "absl/base/nullability.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/executable_run_options.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/literal.h"
#include "xla/service/backend.h"
#include "xla/service/compiler.h"
#include "xla/service/computation_layout.h"
#include "xla/service/computation_placer.h"
#include "xla/service/executable.h"
#include "xla/service/hlo_runner_interface.h"
#include "xla/status_macros.h"
#include "xla/service/service_executable_run_options.h"
#include "xla/service/shaped_buffer.h"
#include "xla/stream_executor/device_memory_allocator.h"
#include "xla/stream_executor/stream_executor.h"
#include "xla/types.h"
#include "xla/stream_executor/platform.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"

Expand Down Expand Up @@ -91,7 +95,7 @@ class HloRunner : public HloRunnerInterface {
using HloRunnerInterface::ExecuteWithExecutable;

absl::StatusOr<Literal> ExecuteWithExecutable(
Executable* executable, absl::Span<const Literal* const> arguments,
OpaqueExecutable* executable, absl::Span<const Literal* const> arguments,
ExecutionProfile* profile) override;

// As Execute(), but accepts and returns device buffers instead of host
Expand All @@ -108,7 +112,8 @@ class HloRunner : public HloRunnerInterface {
bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);

absl::StatusOr<ExecutionOutput> ExecuteWithDeviceBuffers(
Executable* executable, absl::Span<ScopedShapedBuffer const> arguments,
OpaqueExecutable* executable,
absl::Span<ScopedShapedBuffer const> arguments,
ExecutionProfile* profile = nullptr);

// As Execute(), but accepts and returns device buffers instead of host
Expand All @@ -134,10 +139,10 @@ class HloRunner : public HloRunnerInterface {

// Creates an executable object given an HLO module. If run_hlo_passes is
// true, the HLO passes will be run as part of compilation.
absl::StatusOr<std::unique_ptr<Executable>> CreateExecutable(
absl::StatusOr<std::unique_ptr<OpaqueExecutable>> CreateExecutable(
std::unique_ptr<HloModule> module, bool run_hlo_passes) override;

absl::StatusOr<std::unique_ptr<Executable>>
absl::StatusOr<std::unique_ptr<OpaqueExecutable>>
CreateExecutableWithBufferAssignment(
std::unique_ptr<HloModule> module,
const BufferAssignmentProto* /*buffer_assignment_proto*/,
Expand All @@ -162,7 +167,7 @@ class HloRunner : public HloRunnerInterface {
// Note that this call ignores ReplicatedExecutionOptions::run_hlo_passes,
// since we've already compiled the Executable.
absl::StatusOr<std::vector<Literal>> ExecuteReplicated(
Executable* executable, const ReplicatedExecuteOptions& options,
OpaqueExecutable* executable, const ReplicatedExecuteOptions& options,
DeviceAssignment* device_assignment, ExecutionProfile* profile = nullptr);

// Same as above, but with different reusable Executables. This may update the
Expand All @@ -171,7 +176,7 @@ class HloRunner : public HloRunnerInterface {
// Note that this call ignores ReplicatedExecutionOptions::run_hlo_passes,
// since we've already compiled the Executable.
absl::StatusOr<std::vector<Literal>> ExecuteReplicated(
std::function<Executable*(int64_t)> executable_provider,
std::function<OpaqueExecutable*(int64_t)> executable_provider,
std::function<int64_t(int64_t)> argument_count_provider,
std::function<const Literal*(int64_t, int64_t)> argument_provider,
const ReplicatedExecuteOptions& options,
Expand Down Expand Up @@ -199,6 +204,26 @@ class HloRunner : public HloRunnerInterface {

bool HasProperty(HloRunnerPropertyTag::Type tag) const override;

// Helpers to interact with OpaqueExecutable before all users are migrated.
absl::StatusOr<Executable*> ExecutableFromWrapped(
const OpaqueExecutable* wrapped) const;
absl::StatusOr<std::unique_ptr<Executable>> ExecutableFromWrapped(
std::unique_ptr<OpaqueExecutable> wrapped) const;
std::unique_ptr<OpaqueExecutable> WrapExecutable(
std::unique_ptr<Executable> executable) const;
absl::StatusOr<absl::Nonnull<const HloModule*>> HloModuleFromWrapped(
const OpaqueExecutable* wrapped) const override;
// Returns the HloProto of the Executable wrapped by the given
// OpaqueExecutable. This is a temporary API to help move to OpaqueExecutable.
// We need to come up with a better way to obtain this information and
// evaluate whether we need to do this at all. A drop-in migration to
// HloRunnerPjRt (via HloRunnerInterface) won't be possible because this
// information is not available from a PjRt(Loaded)Executable.
//
// TODO: b/393183864 - Remove this API.
absl::StatusOr<absl::Nonnull<const HloProto*>> HloProtoFromWrapped(
const OpaqueExecutable* wrapped) const;

private:
absl::StatusOr<ExecutionOutput> ExecuteWithExecutionInputs(
Executable* executable, std::vector<ExecutionInput> arguments,
Expand Down
4 changes: 1 addition & 3 deletions xla/service/hlo_runner_interface.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,10 @@ limitations under the License.
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/parser/hlo_parser.h"
#include "xla/literal.h"
#include "xla/service/executable.h"
#include "xla/service/hlo_module_config.h"
#include "xla/tsl/platform/env.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"
#include "tsl/platform/statusor.h"

namespace xla {

Expand Down Expand Up @@ -130,7 +128,7 @@ absl::StatusOr<Literal> HloRunnerInterface::ExecuteWithBufferAssignment(
}

absl::StatusOr<Literal> HloRunnerInterface::ExecuteWithExecutable(
Executable* executable, absl::Span<const Literal> arguments,
OpaqueExecutable* executable, absl::Span<const Literal> arguments,
ExecutionProfile* profile) {
// Construct a vector of plain pointers for the arguments.
auto argument_pointers = MakePointerVector<const Literal>(arguments);
Expand Down
Loading

0 comments on commit 264b591

Please sign in to comment.