From 7d7b291874f4ab48729d3e3c4e2b55406f2a7673 Mon Sep 17 00:00:00 2001 From: xla authors Date: Tue, 17 Dec 2024 06:17:55 -0800 Subject: [PATCH] [XLA:CPU] Decouple object loading from JIT compiler. PiperOrigin-RevId: 707069250 --- xla/backends/cpu/codegen/BUILD | 66 ++++++- .../contiguous_section_memory_manager.cc | 3 +- xla/backends/cpu/codegen/jit_compiler.cc | 2 +- xla/backends/cpu/codegen/object_loader.cc | 174 ++++++++++++++++++ xla/backends/cpu/codegen/object_loader.h | 79 ++++++++ .../cpu/codegen/object_loader_test.cc | 162 ++++++++++++++++ xla/backends/cpu/runtime/BUILD | 1 - 7 files changed, 480 insertions(+), 7 deletions(-) create mode 100644 xla/backends/cpu/codegen/object_loader.cc create mode 100644 xla/backends/cpu/codegen/object_loader.h create mode 100644 xla/backends/cpu/codegen/object_loader_test.cc diff --git a/xla/backends/cpu/codegen/BUILD b/xla/backends/cpu/codegen/BUILD index 22ea6cb0cb1311..9f442ae560686c 100644 --- a/xla/backends/cpu/codegen/BUILD +++ b/xla/backends/cpu/codegen/BUILD @@ -26,12 +26,13 @@ cc_library( srcs = ["contiguous_section_memory_manager.cc"], hdrs = ["contiguous_section_memory_manager.h"], deps = [ - "//xla:util", - "@llvm-project//llvm:Core", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@llvm-project//llvm:ExecutionEngine", - "@llvm-project//llvm:OrcJIT", "@llvm-project//llvm:Support", - "@tsl//tsl/platform:logging", + # TODO(basioli): This dependency increases the binary size significantly. + # Consider reducing the dependency size, or use something alternative. + "//xla:util", ], ) @@ -93,6 +94,7 @@ cc_library( "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -353,3 +355,59 @@ cc_library( "@llvm-project//llvm:OrcJIT", ], ) + +cc_library( + name = "object_loader", + srcs = ["object_loader.cc"], + hdrs = ["object_loader.h"], + deps = [ + ":compiled_function_library", + ":contiguous_section_memory_manager", + "//xla/backends/cpu/runtime:function_library", + "//xla/service/cpu:orc_jit_memory_mapper", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:JITLink", + "@llvm-project//llvm:OrcShared", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:ir_headers", + ], +) + +xla_cc_test( + name = "object_loader_test", + srcs = ["object_loader_test.cc"], + deps = [ + ":ir_compiler", + ":jit_compiler", + ":object_loader", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/backends/cpu/runtime:function_library", + "//xla/service:cpu_plugin", + "//xla/service/cpu:executable_proto_cc", + "//xla/service/llvm_ir:llvm_util", + "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:AsmParser", + "@llvm-project//llvm:JITLink", + "@llvm-project//llvm:Object", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:Target", + "@llvm-project//llvm:ir_headers", + "@tsl//tsl/platform:statusor", + ], +) diff --git a/xla/backends/cpu/codegen/contiguous_section_memory_manager.cc b/xla/backends/cpu/codegen/contiguous_section_memory_manager.cc index f30fa63be52ad9..ae15857de011c1 100644 --- a/xla/backends/cpu/codegen/contiguous_section_memory_manager.cc +++ b/xla/backends/cpu/codegen/contiguous_section_memory_manager.cc @@ -20,12 +20,13 @@ limitations under the License. #include #include // NOLINT +#include "absl/log/check.h" +#include "absl/log/log.h" #include "llvm/ExecutionEngine/SectionMemoryManager.h" #include "llvm/Support/Alignment.h" #include "llvm/Support/Memory.h" #include "llvm/Support/Process.h" #include "xla/util.h" -#include "tsl/platform/logging.h" namespace xla::cpu { namespace { diff --git a/xla/backends/cpu/codegen/jit_compiler.cc b/xla/backends/cpu/codegen/jit_compiler.cc index 7f3acba32e57d5..e91e89a0007ff1 100644 --- a/xla/backends/cpu/codegen/jit_compiler.cc +++ b/xla/backends/cpu/codegen/jit_compiler.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/base/call_once.h" #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -57,7 +58,6 @@ limitations under the License. #include "xla/service/cpu/orc_jit_memory_mapper.h" #include "xla/util.h" #include "tsl/platform/cpu_info.h" -#include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" #include "tsl/profiler/lib/traceme_encode.h" diff --git a/xla/backends/cpu/codegen/object_loader.cc b/xla/backends/cpu/codegen/object_loader.cc new file mode 100644 index 00000000000000..ca70110d1e188f --- /dev/null +++ b/xla/backends/cpu/codegen/object_loader.cc @@ -0,0 +1,174 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/codegen/object_loader.h" + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ExecutionEngine/Orc/Core.h" +#include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h" +#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" +#include "llvm/ExecutionEngine/Orc/Shared/ExecutorAddress.h" +#include "llvm/ExecutionEngine/Orc/Shared/ExecutorSymbolDef.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Mangler.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/MemoryBuffer.h" +#include "xla/backends/cpu/codegen/compiled_function_library.h" +#include "xla/backends/cpu/codegen/contiguous_section_memory_manager.h" +#include "xla/backends/cpu/runtime/function_library.h" +#include "xla/service/cpu/orc_jit_memory_mapper.h" + +namespace xla::cpu { + +static std::unique_ptr +CreateObjectLinkingLayer(llvm::orc::ExecutionSession& execution_session) { + return std::make_unique( + execution_session, [] { + return std::make_unique( + orc_jit_memory_mapper::GetInstance()); + }); +} + +ObjectLoader::ObjectLoader(size_t num_dylibs) +/*: target_machine_(std::move(target_machine))*/ { + // LLVM execution session that holds jit-compiled functions. + execution_session_ = std::make_unique( + std::make_unique( + /*SSP=*/nullptr, /*D=*/nullptr)); + + execution_session_->setErrorReporter([](llvm::Error err) { + LOG(ERROR) << "LLVM compilation error: " << llvm::toString(std::move(err)); + }); + + // Create at least one dynamic library for the given jit compiler. + dylibs_.resize(std::max(1, num_dylibs)); + for (size_t i = 0; i < dylibs_.size(); ++i) { + dylibs_[i] = &execution_session_->createBareJITDylib( + absl::StrCat("")); + // TODO using target machine might bring some deps we don't need. + // as a first attempt fully remove it, consider pruning the reqs + // if (definition_generator) { + // dylibs_[i]->addGenerator(definition_generator(target_machine_.get())); + // } + } + + object_layer_ = CreateObjectLinkingLayer(*execution_session_); +} + +absl::Status ObjectLoader::AddObjFile(const std::string& obj_file, + const std::string& memory_buffer_name, + size_t dylib_index) { + if (dylib_index >= dylibs_.size()) { + return absl::Status( + absl::StatusCode::kInvalidArgument, + absl::StrFormat("Invalid dylib index %d (num dylibs: %d))", dylib_index, + dylibs_.size())); + } + + llvm::StringRef data(obj_file.data(), obj_file.size()); + + auto obj_file_mem_buffer = + llvm::MemoryBuffer::getMemBuffer(data, memory_buffer_name); + + if (!obj_file_mem_buffer) { + return absl::Status(absl::StatusCode::kInvalidArgument, + "Failed to create memory buffer"); + } + + llvm::orc::JITDylib* dylib = dylibs_[dylib_index]; + if (auto err = object_layer_->add(*dylib, std::move(obj_file_mem_buffer))) { + return absl::Status( + absl::StatusCode::kInvalidArgument, + absl::StrFormat("Failed to add object file to dylib %d: %s", + dylib_index, llvm::toString(std::move(err)))); + } + + return absl::OkStatus(); +} + +absl::StatusOr> ObjectLoader::Load( + absl::Span symbols, const llvm::DataLayout& data_layout) && { + // Mangle symbol names for the target machine data layout. + auto mangle = [&](absl::string_view name) { + llvm::SmallVector mangled; + llvm::Mangler::getNameWithPrefix(mangled, name, data_layout); + return std::string(mangled.begin(), mangled.end()); + }; + + // Build a symbol lookup set. + llvm::orc::SymbolLookupSet lookup_set; + for (const auto& symbol : symbols) { + VLOG(5) << absl::StreamFormat(" - look up symbol: %s", symbol.name); + lookup_set.add(execution_session_->intern(mangle(symbol.name))); + } + + // Build a search order for the dynamic libraries. + llvm::orc::JITDylibSearchOrder search_order(dylibs_.size()); + for (size_t i = 0; i < dylibs_.size(); ++i) { + search_order[i] = std::make_pair( + dylibs_[i], llvm::orc::JITDylibLookupFlags::MatchExportedSymbolsOnly); + } + + // Look up all requested symbols in the execution session. + auto symbol_map = execution_session_->lookup(std::move(search_order), + std::move(lookup_set)); + + if (auto err = symbol_map.takeError()) { + return absl::Status(absl::StatusCode::kInternal, + absl::StrFormat("%s", llvm::toString(std::move(err)))); + } + + // Resolve type-erased symbol pointers from the symbol map. + using ResolvedSymbol = CompiledFunctionLibrary::ResolvedSymbol; + absl::flat_hash_map resolved_map; + + for (const auto& symbol : symbols) { + auto symbol_name = execution_session_->intern(mangle(symbol.name)); + llvm::orc::ExecutorSymbolDef symbol_def = symbol_map->at(symbol_name); + llvm::orc::ExecutorAddr symbol_addr = symbol_def.getAddress(); + void* ptr = reinterpret_cast(symbol_addr.getValue()); + resolved_map[symbol.name] = ResolvedSymbol{symbol.type_id, ptr}; + } + + return std::make_unique( + std::move(execution_session_), std::move(object_layer_), + std::move(resolved_map)); +} + +ObjectLoader::~ObjectLoader() { + if (execution_session_) { + if (auto err = execution_session_->endSession()) { + execution_session_->reportError(std::move(err)); + } + } +} + +} // namespace xla::cpu diff --git a/xla/backends/cpu/codegen/object_loader.h b/xla/backends/cpu/codegen/object_loader.h new file mode 100644 index 00000000000000..00739eca9f9bf6 --- /dev/null +++ b/xla/backends/cpu/codegen/object_loader.h @@ -0,0 +1,79 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_CODEGEN_OBJECT_LOADER_H_ +#define XLA_BACKENDS_CPU_CODEGEN_OBJECT_LOADER_H_ + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "llvm/ExecutionEngine/Orc/Core.h" +#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" +#include "llvm/IR/DataLayout.h" +#include "xla/backends/cpu/runtime/function_library.h" + +namespace xla::cpu { + +class ObjectLoader { + public: + using Symbol = FunctionLibrary::Symbol; + + explicit ObjectLoader(size_t num_dylibs); + + absl::Status AddObjFile(const std::string& obj_file, + const std::string& memory_buffer_name, + size_t dylib_index = 0); + + absl::StatusOr> Load( + absl::Span symbols, const llvm::DataLayout& data_layout) &&; + + llvm::orc::RTDyldObjectLinkingLayer* object_layer() { + return object_layer_.get(); + } + + llvm::orc::ExecutionSession* execution_session() { + return execution_session_.get(); + } + + absl::StatusOr dylib(size_t dylib_index) { + if (dylib_index >= dylibs_.size()) { + return absl::Status( + absl::StatusCode::kInvalidArgument, + absl::StrFormat("Invalid dylib index %d (num dylibs: %d))", + dylib_index, dylibs_.size())); + } + return dylibs_[dylib_index]; + } + + ~ObjectLoader(); + + private: + std::unique_ptr object_layer_; + std::unique_ptr execution_session_; + + // Non-owning pointers to dynamic libraries created for the execution session. + std::vector dylibs_; + + // std::shared_ptr target_machine_; +}; + +} // namespace xla::cpu + +#endif // XLA_BACKENDS_CPU_CODEGEN_OBJECT_LOADER_H_ diff --git a/xla/backends/cpu/codegen/object_loader_test.cc b/xla/backends/cpu/codegen/object_loader_test.cc new file mode 100644 index 00000000000000..bb9cbe13e18082 --- /dev/null +++ b/xla/backends/cpu/codegen/object_loader_test.cc @@ -0,0 +1,162 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/codegen/object_loader.h" + +#include +#include +#include +#include +#include +#include + +#include +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/AsmParser/Parser.h" +#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/Object/ObjectFile.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Target/TargetOptions.h" +#include "xla/backends/cpu/codegen/ir_compiler.h" +#include "xla/backends/cpu/codegen/jit_compiler.h" +#include "xla/backends/cpu/runtime/function_library.h" +#include "xla/service/cpu/executable.pb.h" +#include "xla/service/llvm_ir/llvm_util.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" + +namespace xla::cpu { +namespace { + +// Parses the LLVM IR into a ThreadSafeModule. +static absl::StatusOr ParseModule( + llvm::orc::ThreadSafeContext& context, absl::string_view ir, + absl::string_view name) { + llvm::SMDiagnostic diagnostic; + llvm::MemoryBufferRef ir_buffer(ir, name); + + auto m = llvm::parseAssembly(ir_buffer, diagnostic, *context.getContext()); + if (m == nullptr) { + return Internal("Failed to parse LLVM IR: %s", + diagnostic.getMessage().str()); + } + + return llvm::orc::ThreadSafeModule(std::move(m), context); +} + +static absl::StatusOr> Compile( + JitCompiler compiler, absl::Span symbols) { + return std::move(compiler).Compile(symbols); +}; + +TEST(ObjectLoader, Load) { + constexpr size_t kNumDyLibs = 1; + auto context = std::make_unique(); + llvm::orc::ThreadSafeContext tsc(std::move(context)); + + std::vector object_files; + auto object_files_saver = + [&object_files](const llvm::Module& /*module*/, + const llvm::object::ObjectFile& object_file) -> void { + object_files.emplace_back(object_file.getData().data(), + object_file.getData().size()); + }; + + JitCompiler::Options options; + options.num_dylibs = kNumDyLibs; + options.ir_compiler_hooks.post_codegen = object_files_saver; + + TF_ASSERT_OK_AND_ASSIGN( + auto compiler, + JitCompiler::Create(llvm::TargetOptions(), std::move(options))); + + constexpr absl::string_view add_in_place_ir = R"( + define void @AddInplace(ptr %arg) { + %v0 = load float, ptr %arg + %v1 = fadd float %v0, %v0 + store float %v1, ptr %arg + ret void + })"; + + auto add_module = [&](absl::string_view ir, absl::string_view name, + size_t dylib_index) -> absl::Status { + TF_ASSIGN_OR_RETURN(llvm::orc::ThreadSafeModule tsm, + ParseModule(tsc, ir, name)); + TF_RETURN_IF_ERROR(compiler.AddModule(std::move(tsm), dylib_index)); + return absl::OkStatus(); + }; + + TF_ASSERT_OK(add_module(add_in_place_ir, "AddInplace", 0)); + + using ScalarFn = void(float*); + std::vector symbols = { + FunctionLibrary::Sym("AddInplace")}; + + llvm::DataLayout data_layout = compiler.target_machine()->createDataLayout(); + TF_ASSERT_OK_AND_ASSIGN(auto function_library_compiled, + Compile(std::move(compiler), symbols)); + + TF_ASSERT_OK_AND_ASSIGN( + ScalarFn * add_in_place_compiled, + function_library_compiled->ResolveFunction("AddInplace")); + + EXPECT_NE(add_in_place_compiled, nullptr); + + auto object_loader(std::make_unique(/*num_dylibs=*/kNumDyLibs)); + { + size_t obj_file_index = 0; + for (auto& obj_file : object_files) { + llvm::StringRef data(obj_file.data(), obj_file.size()); + TF_ASSERT_OK(object_loader->AddObjFile( + obj_file, absl::StrCat("loaded_obj_file_", obj_file_index++))); + } + } + + TF_ASSERT_OK_AND_ASSIGN(auto loaded_function_library, + std::move(*object_loader).Load(symbols, data_layout)); + + TF_ASSERT_OK_AND_ASSIGN( + ScalarFn * loaded_add_in_place, + loaded_function_library->ResolveFunction("AddInplace")); + + EXPECT_NE(loaded_add_in_place, nullptr); + + constexpr float kInputValue = 1.0f; + constexpr float kExpectedOutput = kInputValue + kInputValue; + + float compiled_function_input = kInputValue; + add_in_place_compiled(&compiled_function_input); + EXPECT_EQ(compiled_function_input, kExpectedOutput); + + float loaded_function_input = 1.0f; + loaded_add_in_place(&loaded_function_input); + EXPECT_EQ(loaded_function_input, compiled_function_input); +} + +} // namespace +} // namespace xla::cpu diff --git a/xla/backends/cpu/runtime/BUILD b/xla/backends/cpu/runtime/BUILD index 9ee9efeae75ac7..c80d349c2ae924 100644 --- a/xla/backends/cpu/runtime/BUILD +++ b/xla/backends/cpu/runtime/BUILD @@ -127,7 +127,6 @@ cc_library( hdrs = ["function_library.h"], deps = [ ":kernel_c_api", - "//xla:util", "//xla/tsl/lib/gtl:int_type", "@com_google_absl//absl/status:statusor", "@tsl//tsl/platform:statusor",