diff --git a/third_party/tsl/tsl/platform/default/rocm_rocdl_path.cc b/third_party/tsl/tsl/platform/default/rocm_rocdl_path.cc index a1f3eba243afa..c5373b9fb2e85 100644 --- a/third_party/tsl/tsl/platform/default/rocm_rocdl_path.cc +++ b/third_party/tsl/tsl/platform/default/rocm_rocdl_path.cc @@ -40,6 +40,13 @@ string RocmRoot() { #endif } -string RocdlRoot() { return io::JoinPath(RocmRoot(), "amdgcn/bitcode"); } +string RocdlRoot() { + if (const char* device_lib_path_env = std::getenv("HIP_DEVICE_LIB_PATH")) { + return device_lib_path_env; + } + else{ + return io::JoinPath(RocmRoot(), "amdgcn/bitcode"); + } +} } // namespace tsl diff --git a/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index a3ed5f71fbbb7..4d39208f998a0 100644 --- a/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -854,7 +854,13 @@ absl::StatusOr> EmitModuleToHsaco( ir_fs->flush(); } // Locate lld. - std::string lld_path = tsl::io::JoinPath(tsl::RocmRoot(), "llvm/bin"); + std::string lld_path; + if (std::getenv("LLVM_PATH")) { + lld_path = tsl::io::JoinPath(std::getenv("LLVM_PATH"), "bin"); + } + else{ + lld_path = tsl::io::JoinPath(tsl::RocmRoot(), "llvm/bin"); + } auto lld_program = llvm::sys::findProgramByName("ld.lld", {lld_path}); if (!lld_program) { return xla::Internal("unable to find ld.lld in PATH: %s",