Skip to content

Commit

Permalink
[XLA:CPU] Integrating ObjectLoader into JITCompiler
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 711737132
  • Loading branch information
Google-ML-Automation committed Jan 10, 2025
1 parent f615fcf commit 255f677
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 102 deletions.
1 change: 1 addition & 0 deletions xla/backends/cpu/codegen/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ cc_library(
":contiguous_section_memory_manager",
":cpu_features",
":ir_compiler",
":object_loader",
"//xla:util",
"//xla/backends/cpu/runtime:function_library",
"//xla/service/cpu:orc_jit_memory_mapper",
Expand Down
119 changes: 25 additions & 94 deletions xla/backends/cpu/codegen/jit_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ limitations under the License.

#include "xla/backends/cpu/codegen/jit_compiler.h"

#include <algorithm>
#include <cstddef>
#include <memory>
#include <optional>
Expand All @@ -24,11 +23,9 @@ limitations under the License.

#include "absl/base/call_once.h"
#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
Expand All @@ -39,8 +36,6 @@ limitations under the License.
#include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h"
#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
#include "llvm/ExecutionEngine/Orc/Shared/ExecutorAddress.h"
#include "llvm/ExecutionEngine/Orc/Shared/ExecutorSymbolDef.h"
#include "llvm/ExecutionEngine/Orc/TaskDispatch.h"
#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h"
#include "llvm/IR/Mangler.h"
Expand All @@ -50,10 +45,10 @@ limitations under the License.
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/TargetParser/Host.h"
#include "xla/backends/cpu/codegen/compiled_function_library.h"
#include "xla/backends/cpu/codegen/contiguous_section_memory_manager.h"
#include "xla/backends/cpu/codegen/cpu_features.h"
#include "xla/backends/cpu/codegen/ir_compiler.h"
#include "xla/backends/cpu/codegen/object_loader.h"
#include "xla/backends/cpu/runtime/function_library.h"
#include "xla/service/cpu/orc_jit_memory_mapper.h"
#include "xla/util.h"
Expand Down Expand Up @@ -180,56 +175,50 @@ JitCompiler::JitCompiler(
: target_machine_builder_(std::move(target_machine_builder)),
target_machine_(std::move(target_machine)),
task_dispatcher_(task_dispatcher),
execution_session_(std::move(execution_session)),
object_layer_(CreateObjectLinkingLayer(*execution_session_)),
compile_layer_(CreateCompileLayer(*execution_session_, *object_layer_,
object_loader_(std::make_unique<ObjectLoader>(
std::move(execution_session), num_dylibs)),
compile_layer_(CreateCompileLayer(*object_loader_->execution_session(),
*object_loader_->object_layer(),
std::move(ir_compiler))),
gdb_(llvm::JITEventListener::createGDBRegistrationListener()),
perf_(llvm::JITEventListener::createPerfJITEventListener()) {
// Create at least one dynamic library for the given jit compiler.
dylibs_.resize(std::max<size_t>(1, num_dylibs));
for (size_t i = 0; i < dylibs_.size(); ++i) {
dylibs_[i] = &execution_session_->createBareJITDylib(
absl::StrCat("<xla_jit_dylib_", i, ">"));
if (definition_generator) {
dylibs_[i]->addGenerator(definition_generator(target_machine_.get()));
// TODO(basioli): The definition generator should be passed to the object
// loader rather than to the JIT compiler. This is not done yet in order to
// avoid making the object loader depend on the target machine.
// Until then, the generators are attached here so that tests can pass.
if (definition_generator) {
for (size_t i = 0; i < object_loader_->num_dylibs(); ++i) {
object_loader_->dylib(i).value()->addGenerator(
definition_generator(target_machine_.get()));
}
}

// Register GDB and perf event listeners with the object linking layer.
if (gdb_) object_layer_->registerJITEventListener(*gdb_);
if (perf_) object_layer_->registerJITEventListener(*perf_);
if (gdb_) object_loader_->object_layer()->registerJITEventListener(*gdb_);
if (perf_) object_loader_->object_layer()->registerJITEventListener(*perf_);

// Copied from LLJIT, required to find symbols on Windows.
if (target_machine_->getTargetTriple().isOSBinFormatCOFF()) {
object_layer_->setOverrideObjectFlagsWithResponsibilityFlags(true);
object_layer_->setAutoClaimResponsibilityForObjectSymbols(true);
object_loader_->object_layer()
->setOverrideObjectFlagsWithResponsibilityFlags(true);
object_loader_->object_layer()->setAutoClaimResponsibilityForObjectSymbols(
true);
}
}

JitCompiler::~JitCompiler() {
if (execution_session_) {
if (auto err = execution_session_->endSession()) {
execution_session_->reportError(std::move(err));
}
}
}
JitCompiler::~JitCompiler() = default;

absl::Status JitCompiler::AddModule(llvm::orc::ThreadSafeModule module,
size_t dylib_index) {
if (dylib_index >= dylibs_.size()) {
return Internal("Invalid dylib index %d (num dylibs: %d))", dylib_index,
dylibs_.size());
}

// Set up module for codegen for the target machine at hand.
module.withModuleDo([&](llvm::Module& m) {
m.setDataLayout(target_machine_->createDataLayout());
m.setTargetTriple(target_machine_->getTargetTriple().getTriple());
});

// Add module to the selected dynamic library.
llvm::orc::JITDylib* dylib = dylibs_[dylib_index];
TF_ASSIGN_OR_RETURN(llvm::orc::JITDylib * dylib,
object_loader_->dylib(dylib_index));
if (auto err = compile_layer_->add(*dylib, std::move(module))) {
return Internal("Failed to add module to dylib %d: %s", dylib_index,
llvm::toString(std::move(err)));
Expand All @@ -240,18 +229,7 @@ absl::Status JitCompiler::AddModule(llvm::orc::ThreadSafeModule module,

absl::Status JitCompiler::AddObjFile(
std::unique_ptr<llvm::MemoryBuffer> obj_file, size_t dylib_index) {
if (dylib_index >= dylibs_.size()) {
return Internal("Invalid dylib index %d (num dylibs: %d))", dylib_index,
dylibs_.size());
}

llvm::orc::JITDylib* dylib = dylibs_[dylib_index];
if (auto err = object_layer_->add(*dylib, std::move(obj_file))) {
return Internal("Failed to add object file to dylib %d: %s", dylib_index,
llvm::toString(std::move(err)));
}

return absl::OkStatus();
return object_loader_->AddObjFile(std::move(obj_file), dylib_index);
}

absl::StatusOr<std::unique_ptr<FunctionLibrary>> JitCompiler::Compile(
Expand All @@ -260,55 +238,8 @@ absl::StatusOr<std::unique_ptr<FunctionLibrary>> JitCompiler::Compile(
return TraceMeEncode("JitCompiler::Compile",
{{"num_symbols", symbols.size()}});
});

// Mangle symbol names for the target machine data layout.
llvm::DataLayout data_layout = target_machine_->createDataLayout();
auto mangle = [&](absl::string_view name) {
llvm::SmallVector<char, 40> mangled;
llvm::Mangler::getNameWithPrefix(mangled, name, data_layout);
return std::string(mangled.begin(), mangled.end());
};

// Build a symbol lookup set.
llvm::orc::SymbolLookupSet lookup_set;
for (const auto& symbol : symbols) {
VLOG(5) << absl::StreamFormat(" - look up symbol: %s", symbol.name);
lookup_set.add(execution_session_->intern(mangle(symbol.name)));
}

// Build a search order for the dynamic libraries.
llvm::orc::JITDylibSearchOrder search_order(dylibs_.size());
for (size_t i = 0; i < dylibs_.size(); ++i) {
search_order[i] = std::make_pair(
dylibs_[i], llvm::orc::JITDylibLookupFlags::MatchExportedSymbolsOnly);
}

// Look up all requested symbols in the execution session.
auto symbol_map = execution_session_->lookup(std::move(search_order),
std::move(lookup_set));

// Wait for all compilation tasks to finish.
task_dispatcher_->shutdown();

if (auto err = symbol_map.takeError()) {
return Internal("%s", llvm::toString(std::move(err)));
}

// Resolve type-erased symbol pointers from the symbol map.
using ResolvedSymbol = CompiledFunctionLibrary::ResolvedSymbol;
absl::flat_hash_map<std::string, ResolvedSymbol> resolved_map;

for (const auto& symbol : symbols) {
auto symbol_name = execution_session_->intern(mangle(symbol.name));
llvm::orc::ExecutorSymbolDef symbol_def = symbol_map->at(symbol_name);
llvm::orc::ExecutorAddr symbol_addr = symbol_def.getAddress();
void* ptr = reinterpret_cast<void*>(symbol_addr.getValue());
resolved_map[symbol.name] = ResolvedSymbol{symbol.type_id, ptr};
}

return std::make_unique<CompiledFunctionLibrary>(
std::move(execution_session_), std::move(object_layer_),
std::move(resolved_map));
return std::move(*object_loader_)
.Load(std::move(symbols), target_machine_->createDataLayout());
}

JitCompiler::TaskDispatcher::TaskDispatcher(TaskRunner task_runner)
Expand Down
7 changes: 2 additions & 5 deletions xla/backends/cpu/codegen/jit_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ limitations under the License.
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "xla/backends/cpu/codegen/ir_compiler.h"
#include "xla/backends/cpu/codegen/object_loader.h"
#include "xla/backends/cpu/runtime/function_library.h"
#include "tsl/platform/cpu_info.h"

Expand Down Expand Up @@ -173,13 +174,9 @@ class JitCompiler {

TaskDispatcher* task_dispatcher_; // owned by the execution session held by `object_loader_`

std::unique_ptr<llvm::orc::ExecutionSession> execution_session_;
std::unique_ptr<llvm::orc::RTDyldObjectLinkingLayer> object_layer_;
std::unique_ptr<ObjectLoader> object_loader_;
std::unique_ptr<llvm::orc::IRCompileLayer> compile_layer_;

// Non-owning pointers to dynamic libraries created for the execution session.
std::vector<llvm::orc::JITDylib*> dylibs_;

// Non owning pointer to JIT event listeners for gdb and perf.
llvm::JITEventListener* gdb_; // not owned
llvm::JITEventListener* perf_; // not owned
Expand Down
28 changes: 26 additions & 2 deletions xla/backends/cpu/codegen/object_loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,25 @@ ObjectLoader::ObjectLoader(size_t num_dylibs)
object_layer_ = CreateObjectLinkingLayer(*execution_session_);
}

// Builds an object loader on top of an already-constructed execution session.
//
// Takes ownership of `execution_session`, creates the object linking layer for
// it, and registers `num_dylibs` bare JIT dylibs named "<xla_jit_dylib_N>".
// At least one dylib is always created, even when `num_dylibs` is zero.
//
// No definition generators are attached to the dylibs here: doing so would
// require a target machine, which may bring in dependencies the object loader
// does not need.
// TODO(basioli): avoid code duplication with the other constructor, and
// consider pruning the requirements so generators can be attached here.
ObjectLoader::ObjectLoader(
    std::unique_ptr<llvm::orc::ExecutionSession> execution_session,
    size_t num_dylibs)
    // The object layer is created from the session, so the session member has
    // to be initialized first.
    : execution_session_(std::move(execution_session)),
      object_layer_(CreateObjectLinkingLayer(*execution_session_)) {
  const size_t dylib_count = std::max<size_t>(1, num_dylibs);
  dylibs_.reserve(dylib_count);
  for (size_t index = 0; index < dylib_count; ++index) {
    llvm::orc::JITDylib& dylib = execution_session_->createBareJITDylib(
        absl::StrCat("<xla_jit_dylib_", index, ">"));
    dylibs_.push_back(&dylib);
  }
}

absl::Status ObjectLoader::AddObjFile(const std::string& obj_file,
const std::string& memory_buffer_name,
size_t dylib_index) {
Expand All @@ -98,13 +117,18 @@ absl::Status ObjectLoader::AddObjFile(const std::string& obj_file,
auto obj_file_mem_buffer =
llvm::MemoryBuffer::getMemBuffer(data, memory_buffer_name);

if (!obj_file_mem_buffer) {
return AddObjFile(std::move(obj_file_mem_buffer), dylib_index);
}

absl::Status ObjectLoader::AddObjFile(
std::unique_ptr<llvm::MemoryBuffer> obj_file, size_t dylib_index) {
if (!obj_file) {
return absl::Status(absl::StatusCode::kInvalidArgument,
"Failed to create memory buffer");
}

llvm::orc::JITDylib* dylib = dylibs_[dylib_index];
if (auto err = object_layer_->add(*dylib, std::move(obj_file_mem_buffer))) {
if (auto err = object_layer_->add(*dylib, std::move(obj_file))) {
return absl::Status(
absl::StatusCode::kInvalidArgument,
absl::StrFormat("Failed to add object file to dylib %d: %s",
Expand Down
13 changes: 12 additions & 1 deletion xla/backends/cpu/codegen/object_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,29 +21,40 @@ limitations under the License.
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "llvm/ExecutionEngine/Orc/Core.h"
#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/Support/MemoryBuffer.h"
#include "xla/backends/cpu/codegen/compiled_function_library.h"
#include "xla/backends/cpu/runtime/function_library.h"

namespace xla::cpu {

class ObjectLoader {
public:
using Symbol = FunctionLibrary::Symbol;
using ResolvedSymbol = CompiledFunctionLibrary::ResolvedSymbol;

explicit ObjectLoader(size_t num_dylibs);
ObjectLoader(std::unique_ptr<llvm::orc::ExecutionSession> execution_session,
size_t num_dylibs);

absl::Status AddObjFile(const std::string& obj_file,
const std::string& memory_buffer_name,
size_t dylib_index = 0);

absl::Status AddObjFile(std::unique_ptr<llvm::MemoryBuffer> obj_file,
size_t dylib_index);

absl::StatusOr<std::unique_ptr<FunctionLibrary>> Load(
absl::Span<const Symbol> symbols, const llvm::DataLayout& data_layout) &&;

size_t num_dylibs() const { return dylibs_.size(); }

llvm::orc::RTDyldObjectLinkingLayer* object_layer() {
return object_layer_.get();
}
Expand All @@ -65,8 +76,8 @@ class ObjectLoader {
~ObjectLoader();

private:
std::unique_ptr<llvm::orc::RTDyldObjectLinkingLayer> object_layer_;
std::unique_ptr<llvm::orc::ExecutionSession> execution_session_;
std::unique_ptr<llvm::orc::RTDyldObjectLinkingLayer> object_layer_;

// Non-owning pointers to dynamic libraries created for the execution session.
std::vector<llvm::orc::JITDylib*> dylibs_;
Expand Down

0 comments on commit 255f677

Please sign in to comment.