From 15f2f3e81cc81f2f2dbfe464b0da91f734417ecc Mon Sep 17 00:00:00 2001 From: Alexandros Theodoridis Date: Thu, 31 Oct 2024 12:17:58 +0000 Subject: [PATCH 1/2] Fix cholesky --- xla/service/gpu/ir_emitter_unnested.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/xla/service/gpu/ir_emitter_unnested.cc b/xla/service/gpu/ir_emitter_unnested.cc index deddceeeba600..50c895c40cf74 100644 --- a/xla/service/gpu/ir_emitter_unnested.cc +++ b/xla/service/gpu/ir_emitter_unnested.cc @@ -2898,6 +2898,9 @@ absl::Status IrEmitterUnnested::EmitHloInstruction( if (instr->custom_call_target() == kNopCustomCallTarget) { return absl::OkStatus(); } + if (IsCustomCallToCusolver(*instr)) { + return EmitCholeskyThunk(instr); + } return EmitCustomCallThunk(custom_call); } case HloOpcode::kFusion: { From fb1d171736130d093ef20bbe30f1a175875f4e2a Mon Sep 17 00:00:00 2001 From: Alexandros Theodoridis Date: Mon, 4 Nov 2024 15:12:48 +0000 Subject: [PATCH 2/2] Apply different patch --- xla/service/gpu/ir_emitter_unnested.cc | 27 +++++++++++++------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/xla/service/gpu/ir_emitter_unnested.cc b/xla/service/gpu/ir_emitter_unnested.cc index 50c895c40cf74..25107f7bc7464 100644 --- a/xla/service/gpu/ir_emitter_unnested.cc +++ b/xla/service/gpu/ir_emitter_unnested.cc @@ -55,17 +55,17 @@ limitations under the License. #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/Linker/Linker.h" -#include "mlir/AsmParser/AsmParser.h" // from @llvm-project -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project +#include "mlir/AsmParser/AsmParser.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project #include "mlir/Dialect/MemRef/Transforms/Passes.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/Parser/Parser.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Parser/Parser.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" // from @llvm-project @@ -2875,10 +2875,12 @@ absl::Status IrEmitterUnnested::EmitHloInstruction( if (IsCustomCallToDnnConvolution(*instr)) { return EmitConvolutionThunk(custom_call); } -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM if (IsCustomCallToCusolver(*instr)) { return EmitCholeskyThunk(instr); } +#endif +#if GOOGLE_CUDA if (IsTriangularSolve(*instr)) { return EmitTriangularSolveCustomCall(instr); } @@ -2898,9 +2900,6 @@ absl::Status IrEmitterUnnested::EmitHloInstruction( if (instr->custom_call_target() == kNopCustomCallTarget) { return absl::OkStatus(); } - if (IsCustomCallToCusolver(*instr)) { - return EmitCholeskyThunk(instr); - } return EmitCustomCallThunk(custom_call); } case HloOpcode::kFusion: {