diff --git a/third_party/tsl/tsl/platform/default/rocm_rocdl_path.cc b/third_party/tsl/tsl/platform/default/rocm_rocdl_path.cc index a1f3eba243afa4..a1934f81e35723 100644 --- a/third_party/tsl/tsl/platform/default/rocm_rocdl_path.cc +++ b/third_party/tsl/tsl/platform/default/rocm_rocdl_path.cc @@ -26,7 +26,7 @@ limitations under the License. namespace tsl { -string RocmRoot() { +std::string RocmRoot() { #if TENSORFLOW_USE_ROCM if (const char* rocm_path_env = std::getenv("ROCM_PATH")) { VLOG(3) << "ROCM root = " << rocm_path_env; @@ -40,6 +40,12 @@ string RocmRoot() { #endif } -string RocdlRoot() { return io::JoinPath(RocmRoot(), "amdgcn/bitcode"); } +std::string RocdlRoot() { + if (const char* device_lib_path_env = std::getenv("HIP_DEVICE_LIB_PATH")) { + return device_lib_path_env; + } else { + return io::JoinPath(RocmRoot(), "amdgcn/bitcode"); + } +} } // namespace tsl diff --git a/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index 13e17bbb477f24..721cae6ac3c269 100644 --- a/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -855,7 +856,12 @@ absl::StatusOr> EmitModuleToHsaco( ir_fs->flush(); } // Locate lld. - std::string lld_path = tsl::io::JoinPath(tsl::RocmRoot(), "llvm/bin"); + std::string lld_path; + if (std::getenv("LLVM_PATH")) { + lld_path = tsl::io::JoinPath(std::getenv("LLVM_PATH"), "bin"); + } else { + lld_path = tsl::io::JoinPath(tsl::RocmRoot(), "llvm/bin"); + } auto lld_program = llvm::sys::findProgramByName("ld.lld", {lld_path}); if (!lld_program) { return xla::Internal("unable to find ld.lld in PATH: %s",