From 1d25524373b57c1adef7440027c5a0d10a443ef3 Mon Sep 17 00:00:00 2001 From: Harsha HS Date: Fri, 21 Jun 2024 13:53:22 -0500 Subject: [PATCH 1/7] Revert "added explicit nullptr" This reverts commit 95465d6f08b4485f9f6165e9324fe5a6059949a0. --- xla/stream_executor/rocm/rocm_blas.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xla/stream_executor/rocm/rocm_blas.cc b/xla/stream_executor/rocm/rocm_blas.cc index 13ca9ea67d804..6627f3175dfc3 100644 --- a/xla/stream_executor/rocm/rocm_blas.cc +++ b/xla/stream_executor/rocm/rocm_blas.cc @@ -344,9 +344,9 @@ absl::Status ROCMBlas::DoBlasInternalImpl(FuncT rocblas_func, Stream *stream, } } #if TF_ROCM_VERSION >= 60000 - if (auto *workspace = GetWorkspace(); workspace != nullptr && - workspace->opaque() != nullptr && workspace->size() > 0) { - (void)wrap::rocblas_set_workspace(blas_, workspace->opaque(), + if (auto *workspace = GetWorkspace(); + workspace != nullptr && workspace->opaque() && workspace->size() > 0) { + (void)wrap::rocblas_set_workspace(blas_, workspace->opaque(), workspace->size()); } #endif From d0d671eda2a7a8ae5d1ef3c08095712c33c0ba19 Mon Sep 17 00:00:00 2001 From: Harsha HS Date: Fri, 21 Jun 2024 13:53:32 -0500 Subject: [PATCH 2/7] Revert "beautified rocblas wrapper" This reverts commit 616b18d057ce40bd85443fd8d89b8cb4b687082c. --- xla/stream_executor/rocm/rocblas_wrapper.h | 28 +++++++++++----------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/xla/stream_executor/rocm/rocblas_wrapper.h b/xla/stream_executor/rocm/rocblas_wrapper.h index 50dc21cbc0c87..98bcca08c9bf6 100644 --- a/xla/stream_executor/rocm/rocblas_wrapper.h +++ b/xla/stream_executor/rocm/rocblas_wrapper.h @@ -257,22 +257,22 @@ using stream_executor::internal::CachedDsoLoader::GetRocblasDsoHandle; __macro(rocblas_zgemm_strided_batched) \ __macro(rocblas_gemm_ex) \ __macro(rocblas_gemm_strided_batched_ex) \ - __macro(rocblas_gemm_ex_get_solutions) \ - __macro(rocblas_gemm_ex_get_solutions_by_type) \ - __macro(rocblas_gemm_batched_ex_get_solutions) \ + __macro(rocblas_gemm_ex_get_solutions) \ + __macro(rocblas_gemm_ex_get_solutions_by_type) \ + __macro(rocblas_gemm_batched_ex_get_solutions) \ __macro(rocblas_gemm_batched_ex_get_solutions_by_type) \ __macro(rocblas_gemm_strided_batched_ex_get_solutions) \ - __macro(rocblas_is_managing_device_memory) \ - __macro(rocblas_is_user_managing_device_memory) \ - __macro(rocblas_set_workspace) \ - __macro(rocblas_strsm_batched) \ - __macro(rocblas_dtrsm_batched) \ - __macro(rocblas_ctrsm_batched) \ - __macro(rocblas_ztrsm_batched) \ - __macro(rocblas_create_handle) \ - __macro(rocblas_destroy_handle) \ - __macro(rocblas_get_stream) \ - __macro(rocblas_set_stream) \ + __macro(rocblas_is_managing_device_memory) \ + __macro(rocblas_is_user_managing_device_memory) \ + __macro(rocblas_set_workspace) \ + __macro(rocblas_strsm_batched) \ + __macro(rocblas_dtrsm_batched) \ + __macro(rocblas_ctrsm_batched) \ + __macro(rocblas_ztrsm_batched) \ + __macro(rocblas_create_handle) \ + __macro(rocblas_destroy_handle) \ + __macro(rocblas_get_stream) \ + __macro(rocblas_set_stream) \ __macro(rocblas_set_atomics_mode) // clang-format on From 4b9bb22dfe4da38d385750e800aef372c08f8288 Mon Sep 17 00:00:00 2001 From: Harsha HS Date: Fri, 21 Jun 2024 13:53:47 -0500 Subject: [PATCH 3/7] Revert "addressing reviewer comments" This reverts commit 300594d9e6dca7b25f01ad4fffdcd1317dae1896.
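Note for reviewers: patches 1-3 only reshuffle the guarded workspace binding that patch 4 later removes in favor of ScopedWorkspace. A minimal, self-contained sketch of that guard pattern follows; DeviceMemoryBase here is a stripped-down stand-in for the real stream-executor type, and the set-workspace callable is assumed to follow the rocBLAS C API shown in the hunk above (illustrative only, not part of this patch):

#include <cstddef>

// Stand-in for stream_executor::DeviceMemoryBase (illustrative only).
struct DeviceMemoryBase {
  void* opaque() const { return ptr_; }
  size_t size() const { return size_; }
  void* ptr_ = nullptr;
  size_t size_ = 0;
};

// Bind a scratch buffer to a BLAS handle only when it is actually usable:
// non-null descriptor, non-null device pointer, non-zero size.
template <typename Handle, typename SetWorkspaceFn>
void MaybeSetWorkspace(Handle blas, DeviceMemoryBase* workspace,
                       SetWorkspaceFn set_workspace) {
  if (workspace != nullptr && workspace->opaque() != nullptr &&
      workspace->size() > 0) {
    // The status is deliberately ignored: failing to attach a workspace
    // only costs performance, it does not affect correctness.
    (void)set_workspace(blas, workspace->opaque(), workspace->size());
  }
}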
--- xla/stream_executor/rocm/rocblas_wrapper.h | 2 +- xla/stream_executor/rocm/rocm_blas.cc | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/xla/stream_executor/rocm/rocblas_wrapper.h b/xla/stream_executor/rocm/rocblas_wrapper.h index 98bcca08c9bf6..8d3fa831f1c1d 100644 --- a/xla/stream_executor/rocm/rocblas_wrapper.h +++ b/xla/stream_executor/rocm/rocblas_wrapper.h @@ -66,7 +66,7 @@ using stream_executor::internal::CachedDsoLoader::GetRocblasDsoHandle; return f; \ } \ template \ - auto operator()(Args... args) { \ + auto operator()(Args... args) { \ return DynLoad()(args...); \ } \ } __name; diff --git a/xla/stream_executor/rocm/rocm_blas.cc b/xla/stream_executor/rocm/rocm_blas.cc index 6627f3175dfc3..b1229df22c047 100644 --- a/xla/stream_executor/rocm/rocm_blas.cc +++ b/xla/stream_executor/rocm/rocm_blas.cc @@ -345,9 +345,8 @@ absl::Status ROCMBlas::DoBlasInternalImpl(FuncT rocblas_func, Stream *stream, } #if TF_ROCM_VERSION >= 60000 if (auto *workspace = GetWorkspace(); - workspace != nullptr && workspace->opaque() && workspace->size() > 0) { - (void)wrap::rocblas_set_workspace(blas_, workspace->opaque(), - workspace->size()); + workspace && workspace->opaque() && workspace->size() > 0) { + (void)wrap::rocblas_set_workspace(blas_, workspace->opaque(), workspace->size()); } #endif From 9f7c1db842a314de3e0f91fe8ddd4f9b2dba4421 Mon Sep 17 00:00:00 2001 From: Harsha HS Date: Fri, 21 Jun 2024 13:53:57 -0500 Subject: [PATCH 4/7] Revert "added memory management functions" This reverts commit 084d266f92257c9bbf96f2eb3e062308c2854dc6. --- xla/service/gpu/matmul_utils.cc | 6 +++--- xla/stream_executor/rocm/rocblas_wrapper.h | 5 +---- xla/stream_executor/rocm/rocm_blas.cc | 6 ------ 3 files changed, 4 insertions(+), 13 deletions(-) diff --git a/xla/service/gpu/matmul_utils.cc b/xla/service/gpu/matmul_utils.cc index 596c36e1f1c6f..8baa272d06e7a 100644 --- a/xla/service/gpu/matmul_utils.cc +++ b/xla/service/gpu/matmul_utils.cc @@ -616,6 +616,9 @@ absl::Status DoGemm(const se::gpu::MatrixDescriptor& lhs, return absl::InternalError("No Blas support for stream"); } + // Set a workspace for all Blas operations launched below. + se::blas::BlasSupport::ScopedWorkspace scoped_workspace(blas, &workspace); + if (algorithm) { return DoGemmWithAlgorithm( lhs, rhs, output, workspace, alpha, beta, stream, precision_algorithm, @@ -623,9 +626,6 @@ absl::Status DoGemm(const se::gpu::MatrixDescriptor& lhs, context); } - // Set a workspace for all Blas operations launched below. - se::blas::BlasSupport::ScopedWorkspace scoped_workspace(blas, &workspace); - if (output.batch_size != 1) { return blas->BlasGemmStridedBatched( stream, lhs.transpose, rhs.transpose, output.m, output.n, output.k, diff --git a/xla/stream_executor/rocm/rocblas_wrapper.h b/xla/stream_executor/rocm/rocblas_wrapper.h index 8d3fa831f1c1d..3d444ab83a0ee 100644 --- a/xla/stream_executor/rocm/rocblas_wrapper.h +++ b/xla/stream_executor/rocm/rocblas_wrapper.h @@ -66,7 +66,7 @@ using stream_executor::internal::CachedDsoLoader::GetRocblasDsoHandle; return f; \ } \ template \ - auto operator()(Args... args) { \ + rocblas_status operator()(Args... 
args) { \ return DynLoad()(args...); \ } \ } __name; @@ -262,9 +262,6 @@ using stream_executor::internal::CachedDsoLoader::GetRocblasDsoHandle; __macro(rocblas_gemm_batched_ex_get_solutions) \ __macro(rocblas_gemm_batched_ex_get_solutions_by_type) \ __macro(rocblas_gemm_strided_batched_ex_get_solutions) \ - __macro(rocblas_is_managing_device_memory) \ - __macro(rocblas_is_user_managing_device_memory) \ - __macro(rocblas_set_workspace) \ __macro(rocblas_strsm_batched) \ __macro(rocblas_dtrsm_batched) \ __macro(rocblas_ctrsm_batched) \ diff --git a/xla/stream_executor/rocm/rocm_blas.cc b/xla/stream_executor/rocm/rocm_blas.cc index b1229df22c047..45ccf7c0f60c4 100644 --- a/xla/stream_executor/rocm/rocm_blas.cc +++ b/xla/stream_executor/rocm/rocm_blas.cc @@ -343,12 +343,6 @@ absl::Status ROCMBlas::DoBlasInternalImpl(FuncT rocblas_func, Stream *stream, << ": " << ToString(ret); } } -#if TF_ROCM_VERSION >= 60000 - if (auto *workspace = GetWorkspace(); - workspace && workspace->opaque() && workspace->size() > 0) { - (void)wrap::rocblas_set_workspace(blas_, workspace->opaque(), workspace->size()); - } -#endif ret = rocblas_func(blas_, std::forward(args)...); if (ret != rocblas_status_success) { From 9f04d2e1e878050c42c148b5aaef0ed3936dce6c Mon Sep 17 00:00:00 2001 From: Eugene Kuznetsov Date: Thu, 8 Feb 2024 13:22:05 +0000 Subject: [PATCH 5/7] Fused convolution+bias+activation --- xla/service/gpu/BUILD | 1 + xla/service/gpu/amdgpu_compiler.cc | 3 + xla/service/gpu/cudnn_fused_conv_rewriter.cc | 34 +- xla/service/gpu/cudnn_fused_conv_rewriter.h | 4 +- .../gpu/cudnn_fused_conv_rewriter_test.cc | 174 ++- xla/service/gpu/runtime/convolution_thunk.cc | 5 +- xla/stream_executor/dnn.cc | 1 + xla/stream_executor/dnn.h | 3 +- xla/stream_executor/rocm/rocm_dnn.cc | 1250 +++++++++++------ xla/stream_executor/rocm/rocm_dnn.h | 81 +- xla/stream_executor/rocm/rocm_helpers.cu.cc | 85 +- xla/tests/convolution_test.cc | 41 +- 12 files changed, 1142 insertions(+), 540 deletions(-) diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index 234319fec692f..c99e3b963100c 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -3990,6 +3990,7 @@ cc_library( ":conv_algorithm_picker", ":cublas_pad_for_gemms", ":cublas_padding_requirements", + ":cudnn_fused_conv_rewriter", ":cusolver_rewriter", ":gemm_algorithm_picker", ":gemm_rewriter", diff --git a/xla/service/gpu/amdgpu_compiler.cc b/xla/service/gpu/amdgpu_compiler.cc index 2a59d086d0223..e47054f633f03 100644 --- a/xla/service/gpu/amdgpu_compiler.cc +++ b/xla/service/gpu/amdgpu_compiler.cc @@ -31,6 +31,7 @@ limitations under the License. #include "xla/service/float_normalization.h" #include "xla/service/gpu/autotuner_util.h" #include "xla/service/gpu/conv_algorithm_picker.h" +#include "xla/service/gpu/cudnn_fused_conv_rewriter.h" #include "xla/service/gpu/cublas_pad_for_gemms.h" #include "xla/service/gpu/cublas_padding_requirements.h" #include "xla/service/gpu/cusolver_rewriter.h" @@ -109,6 +110,8 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization( pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); + auto rcc = std::get(gpu_version); + pipeline.AddPass(rcc); // The conv padding/vectorization passes which we need to get rid of. 
They // also leave behind unnecessary tuple/get-tuple-element pairs that diff --git a/xla/service/gpu/cudnn_fused_conv_rewriter.cc b/xla/service/gpu/cudnn_fused_conv_rewriter.cc index cb8e0ec95e3b9..8efa13b9ee22e 100644 --- a/xla/service/gpu/cudnn_fused_conv_rewriter.cc +++ b/xla/service/gpu/cudnn_fused_conv_rewriter.cc @@ -96,6 +96,10 @@ bool IsNonDepthwiseConvCustomCall(const HloInstruction* instr) { return IsConvCustomCall(instr) && !IsConvDepthwise(instr); } +bool IsROCm(se::GpuComputeCapability cc) { + return std::holds_alternative(cc); +} + // elu, relu6, and leaky-relu activations are supported in cudnn via the // "runtime fusion" engine, which JIT compiles C++ code. This can be slow to // compile, so we guard it with a debug option. @@ -106,8 +110,12 @@ bool IsNonDepthwiseConvCustomCall(const HloInstruction* instr) { // Note that as of writing, xla_gpu_use_runtime_fusion is disabled by default // due to apparent bugs in cudnn 8.9.0. See debug_options_flags.cc for details. bool ShouldUseCudnnRuntimeFusion(const DebugOptions& debug_opts, - se::CudaComputeCapability cc) { - return debug_opts.xla_gpu_use_runtime_fusion() && cc.IsAtLeast(7, 5); + se::GpuComputeCapability cc) { + const auto* cuda_cc = std::get_if(&cc); + if(cuda_cc != nullptr) + return debug_opts.xla_gpu_use_runtime_fusion() && cuda_cc->IsAtLeast(7, 5); + else + return true; } bool IsSuitableForCudnnRuntimeFusion(HloInstruction* conv) { @@ -984,7 +992,7 @@ absl::StatusOr FuseSideInputAlpha(HloComputation* comp) { } absl::StatusOr FuseElu(HloComputation* comp, - se::CudaComputeCapability cc) { + se::GpuComputeCapability cc) { if (!ShouldUseCudnnRuntimeFusion(comp->parent()->config().debug_options(), cc)) { return false; @@ -1085,7 +1093,7 @@ absl::StatusOr FuseRelu(HloComputation* comp) { } absl::StatusOr FuseRelu6(HloComputation* comp, - se::CudaComputeCapability cc) { + se::GpuComputeCapability cc) { if (!ShouldUseCudnnRuntimeFusion(comp->parent()->config().debug_options(), cc)) { return false; @@ -1134,7 +1142,7 @@ absl::StatusOr FuseRelu6(HloComputation* comp, } absl::StatusOr FuseLeakyRelu(HloComputation* comp, - se::CudaComputeCapability cc) { + se::GpuComputeCapability cc) { if (!ShouldUseCudnnRuntimeFusion(comp->parent()->config().debug_options(), cc)) { return false; @@ -1254,7 +1262,10 @@ absl::StatusOr FuseConvertToF16(HloComputation* comp) { return changed; } -absl::StatusOr FuseConvertToS8(HloComputation* comp) { +absl::StatusOr FuseConvertToS8(HloComputation* comp, + se::GpuComputeCapability cc) { + if(IsROCm(cc)) + return false; bool changed = false; for (HloInstruction* instr : comp->MakeInstructionPostOrder()) { HloInstruction* gte = nullptr; @@ -1480,9 +1491,12 @@ absl::StatusOr CudnnFusedConvRewriter::Run( bool changed = false; // Rewrite FP8 convolutions and supported adjacent pointwise ops into a // ForwardGraph Custom Call. - TF_ASSIGN_OR_RETURN(changed, F8GraphConv(comp, compute_capability_)); - if (changed) { - return changed; + if(!IsROCm(compute_capability_)) { + auto cc = std::get(compute_capability_); + TF_ASSIGN_OR_RETURN(changed, F8GraphConv(comp, cc)); + if (changed) { + return changed; + } } // Fuse "inside out" starting with the operations closest to the conv. 
TF_ASSIGN_OR_RETURN(changed, FuseRemoveConvertInConv(comp)); @@ -1516,7 +1530,7 @@ absl::StatusOr CudnnFusedConvRewriter::Run( TF_ASSIGN_OR_RETURN(changed, FuseConvertToF16(comp)); any_changed |= changed; - TF_ASSIGN_OR_RETURN(changed, FuseConvertToS8(comp)); + TF_ASSIGN_OR_RETURN(changed, FuseConvertToS8(comp, compute_capability_)); any_changed |= changed; // f16 convs' bias+side-input can appear before or after conversion to f16. diff --git a/xla/service/gpu/cudnn_fused_conv_rewriter.h b/xla/service/gpu/cudnn_fused_conv_rewriter.h index bc7291d262a61..ff1d156525539 100644 --- a/xla/service/gpu/cudnn_fused_conv_rewriter.h +++ b/xla/service/gpu/cudnn_fused_conv_rewriter.h @@ -100,6 +100,8 @@ class CudnnFusedConvRewriter : public HloModulePass { public: explicit CudnnFusedConvRewriter(se::CudaComputeCapability cc) : compute_capability_(cc) {} + explicit CudnnFusedConvRewriter(se::RocmComputeCapability cc) + : compute_capability_(cc) {} absl::string_view name() const override { return "cudnn-fused-convolution-rewriter"; @@ -111,7 +113,7 @@ class CudnnFusedConvRewriter : public HloModulePass { const absl::flat_hash_set& execution_threads) override; private: - const se::CudaComputeCapability compute_capability_; + const se::GpuComputeCapability compute_capability_; }; } // namespace gpu diff --git a/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc b/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc index 4a55b9dc9eb40..2add39d4a3b1b 100644 --- a/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc +++ b/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -68,8 +69,16 @@ namespace m = match; using ::testing::HasSubstr; using ::testing::Not; +std::vector f16f32f64 = {"f16", "f32", "f64"}; +std::vector f16f32 = {"f16", "f32"}; + class CudnnFusedConvRewriterHloTest : public HloTestBase { public: + bool IsCuda() { return std::holds_alternative< + se::CudaComputeCapability>( + backend().default_stream_executor() + ->GetDeviceDescription().gpu_compute_capability()); + } se::CudaComputeCapability GetCudaComputeCapability() { return backend() .default_stream_executor() @@ -85,6 +94,11 @@ class CudnnFusedConvRewriterHloTest : public HloTestBase { class CudnnFusedConvRewriterTest : public GpuCodegenTest { public: + bool IsCuda() { return std::holds_alternative< + se::CudaComputeCapability>( + backend().default_stream_executor() + ->GetDeviceDescription().gpu_compute_capability()); + } se::CudaComputeCapability GetCudaComputeCapability() { return backend() .default_stream_executor() @@ -119,7 +133,7 @@ class CudnnFusedConvRewriterTest : public GpuCodegenTest { } void TestMatchWithAllTypes(absl::string_view hlo_string) { - for (absl::string_view type : {"f16", "f32", "f64"}) { + for (absl::string_view type : (IsCuda() ? f16f32f64 : f16f32)) { const std::string hlo_with_new_type = absl::StrReplaceAll(hlo_string, {{"TYPE", type}}); std::string optimized_hlo_string = GetOptimizedHlo(hlo_with_new_type); @@ -157,7 +171,7 @@ class CudnnFusedConvRewriterTest : public GpuCodegenTest { } void TestNotMatchWithAllTypes(absl::string_view hlo_string) { - for (absl::string_view type : {"f16", "f32", "f64"}) { + for (absl::string_view type : (IsCuda() ? 
f16f32f64 : f16f32)) { const std::string hlo_with_new_type = absl::StrReplaceAll(hlo_string, {{"TYPE", type}}); std::string optimized_hlo_string = GetOptimizedHlo(hlo_with_new_type); @@ -170,6 +184,8 @@ class CudnnFusedConvRewriterTest : public GpuCodegenTest { void TestF8(std::string pre_hlo_string, std::string custom_call_string, std::string serialized_graph_string) { + if(!IsCuda()) + return; if (GetCudaComputeCapability().IsAtLeast( se::CudaComputeCapability::HOPPER)) { // On Hopper and newer architectures, test numerical correctness and @@ -244,6 +260,23 @@ class CudnnFusedConvRewriterTest : public GpuCodegenTest { } }; +#if GOOGLE_CUDA + #if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8900) + #define MAYBE_SKIP_TEST(CAUSE) \ + do { \ + if(absl::string_view(CAUSE) == "F8") \ + GTEST_SKIP() << "FP8 convolutions require CUDA 12 and cuDNN 8.9.";\ + } while(0) + #else + #define MAYBE_SKIP_TEST(CAUSE) + #endif +#else +#define MAYBE_SKIP_TEST(CAUSE) \ + do { \ + GTEST_SKIP() << "ROCm does not support " CAUSE " fusion"; \ + } while(0) +#endif + TEST_F(CudnnFusedConvRewriterTest, TestConvOnly) { // max(0, conv(x, w)); TestMatchWithAllTypes(R"( @@ -298,6 +331,83 @@ TEST_F(CudnnFusedConvRewriterTest, TestBias) { })"); } +TEST_F(CudnnFusedConvRewriterTest, Test3D) { + // max(0, conv(x, w) + bias); + std::string body = R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,5,7,64] broadcast(zero), dimensions={} + + input = TYPE[1,3,5,7,64] parameter(0) + filter = TYPE[3,3,3,64,64] parameter(1) + bias = TYPE[64] parameter(2) + + conv = TYPE[1,3,5,7,64] convolution(input, filter), window={size=3x3x3 pad=1_1x1_1x1_1}, dim_labels=b012f_012io->b012f, feature_group_count=1 + broadcasted_bias = TYPE[1,3,5,7,64] broadcast(bias), dimensions={4} + add1 = TYPE[1,3,5,7,64] add(conv, broadcasted_bias) + )"; + + std::string relu = R"( + ROOT relu = TYPE[1,3,5,7,64] maximum(zeros, add1) + })"; + + std::string elu = R"( + cmp = pred[1,3,5,7,64] compare(add1, zeros), direction=GT + expm1 = TYPE[1,3,5,7,64] exponential-minus-one(add1) + ROOT elu = TYPE[1,3,5,7,64] select(cmp, add1, expm1) + })"; + + TestMatchWithAllTypes(body+relu); + if (!IsCuda()) + TestMatchWithAllTypes(body+elu); +} + +TEST_F(CudnnFusedConvRewriterTest, TestBiasMultiCall) { + + // max(0, conv(x, w) + bias); + std::string code = R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,<<>>,64] broadcast(zero), dimensions={} + + input = TYPE[1,<<>>,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + bias = TYPE[64] parameter(2) + + conv = TYPE[1,<<>>,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + broadcasted_bias = TYPE[1,<<>>,64] broadcast(bias), dimensions={3} + add1 = TYPE[1,<<>>,64] add(conv, broadcasted_bias) + ROOT relu = TYPE[1,<<>>,64] maximum(zeros, add1) + })"; + absl::flat_hash_map replacements; + replacements["<<>>"] = "3,3"; + TestMatchWithAllTypes(absl::StrReplaceAll(code, replacements)); + replacements["<<>>"] = "5,5"; + TestMatchWithAllTypes(absl::StrReplaceAll(code, replacements)); + replacements["<<>>"] = "3,3"; + TestMatchWithAllTypes(absl::StrReplaceAll(code, replacements)); +} + +TEST_F(CudnnFusedConvRewriterTest, TestBiasNoRelu) { + // conv(x, w) + bias; + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + bias = TYPE[64] parameter(2) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 
pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3} + ROOT add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias) + })"); +} + TEST_F(CudnnFusedConvRewriterTest, DontFuseBiasWithDepthwiseConv) { // conv(x, w) + bias; TestNotMatchWithAllTypes(R"( @@ -365,7 +475,7 @@ TEST_F(CudnnFusedConvRewriterTest, DontFuseEluWithDepthwiseConv) { } TEST_F(CudnnFusedConvRewriterTest, TestRelu6) { - if (!GetCudaComputeCapability().IsAtLeast( + if (IsCuda() && !GetCudaComputeCapability().IsAtLeast( se::CudaComputeCapability::AMPERE)) { GTEST_SKIP() << "Conv-Bias-Relu6 fusion is supported and recommended with " "the Nvidia Ampere+ GPUs."; @@ -393,12 +503,11 @@ TEST_F(CudnnFusedConvRewriterTest, TestRelu6) { // number of input/output channels. Check that we don't try to run this conv // with runtime fusion (or, if we do, that it works!). TEST_F(CudnnFusedConvRewriterTest, TestRelu6OddChannels) { - if (!GetCudaComputeCapability().IsAtLeast( + if (IsCuda() && !GetCudaComputeCapability().IsAtLeast( se::CudaComputeCapability::AMPERE)) { GTEST_SKIP() << "Conv-Bias-Relu6 fusion is supported and recommended with " "the Nvidia Ampere+ GPUs."; } - TestMatchWithAllTypes(R"( HloModule Test ENTRY Test { @@ -415,7 +524,7 @@ TEST_F(CudnnFusedConvRewriterTest, TestRelu6OddChannels) { } TEST_F(CudnnFusedConvRewriterTest, TestLeakyRelu) { - if (!GetCudaComputeCapability().IsAtLeast( + if (IsCuda() && !GetCudaComputeCapability().IsAtLeast( se::CudaComputeCapability::AMPERE)) { GTEST_SKIP() << "Conv-Bias-LeakyRelu fusion is supported and recommended with " @@ -730,9 +839,7 @@ TEST_F(CudnnFusedConvRewriterTest, TestPreservesFeatureGroupCount) { } TEST_F(CudnnFusedConvRewriterTest, TestConvF8) { -#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8900) - GTEST_SKIP() << "FP8 convolutions require CUDA 12 and cuDNN 8.9."; -#endif + MAYBE_SKIP_TEST("F8"); TestF8( // pre_hlo R"( @@ -755,9 +862,7 @@ TEST_F(CudnnFusedConvRewriterTest, TestConvF8) { } TEST_F(CudnnFusedConvRewriterTest, TestConvScaledOutputF8) { -#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8900) - GTEST_SKIP() << "FP8 convolutions require CUDA 12 and cuDNN 8.9."; -#endif + MAYBE_SKIP_TEST("F8"); TestF8( // pre_hlo R"( @@ -791,9 +896,7 @@ TEST_F(CudnnFusedConvRewriterTest, TestConvScaledOutputF8) { } TEST_F(CudnnFusedConvRewriterTest, TestConvInvscaledOutputF8) { -#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8900) - GTEST_SKIP() << "FP8 convolutions require CUDA 12 and cuDNN 8.9."; -#endif + MAYBE_SKIP_TEST("F8"); TestF8( // pre_hlo R"( @@ -827,9 +930,7 @@ TEST_F(CudnnFusedConvRewriterTest, TestConvInvscaledOutputF8) { } TEST_F(CudnnFusedConvRewriterTest, TestConvScaledF8Parameterized) { -#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8900) - GTEST_SKIP() << "FP8 convolutions require CUDA 12 and cuDNN 8.9."; -#endif + MAYBE_SKIP_TEST("F8"); TestF8Parameterized( // pre_hlo R"( @@ -869,9 +970,7 @@ TEST_F(CudnnFusedConvRewriterTest, TestConvScaledF8Parameterized) { } TEST_F(CudnnFusedConvRewriterTest, TestConvScaledBiasF8) { -#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8900) - GTEST_SKIP() << "FP8 convolutions require CUDA 12 and cuDNN 8.9."; -#endif + MAYBE_SKIP_TEST("F8"); TestF8( // pre_hlo R"( @@ -913,9 +1012,7 @@ TEST_F(CudnnFusedConvRewriterTest, TestConvScaledBiasF8) { } TEST_F(CudnnFusedConvRewriterTest, TestConvScaledReluF8) { -#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8900) - GTEST_SKIP() << "FP8 convolutions require CUDA 12 and cuDNN 8.9."; -#endif + MAYBE_SKIP_TEST("F8"); TestF8( // pre_hlo 
R"( @@ -952,6 +1049,7 @@ TEST_F(CudnnFusedConvRewriterTest, TestConvScaledReluF8) { } TEST_F(CudnnFusedConvRewriterTest, TestConvAmaxF8) { + MAYBE_SKIP_TEST("F8"); TestF8( // pre_hlo R"( @@ -1001,6 +1099,7 @@ TEST_F(CudnnFusedConvRewriterTest, TestConvAmaxF8) { } TEST_F(CudnnFusedConvRewriterTest, TestConvReluAmaxF8) { + MAYBE_SKIP_TEST("F8"); TestF8( // pre_hlo R"( @@ -1053,9 +1152,7 @@ TEST_F(CudnnFusedConvRewriterTest, TestConvReluAmaxF8) { } TEST_F(CudnnFusedConvRewriterTest, TestConvScaledOutputMultipleUsersF8) { -#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8900) - GTEST_SKIP() << "FP8 convolutions require CUDA 12 and cuDNN 8.9."; -#endif + MAYBE_SKIP_TEST("F8"); TestF8( // pre_hlo R"( @@ -1095,9 +1192,7 @@ TEST_F(CudnnFusedConvRewriterTest, TestConvScaledOutputMultipleUsersF8) { } TEST_F(CudnnFusedConvRewriterTest, TestConvScaledOutputUnsupportedUserF8) { -#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8900) - GTEST_SKIP() << "FP8 convolutions require CUDA 12 and cuDNN 8.9."; -#endif + MAYBE_SKIP_TEST("F8"); TestF8( // pre_hlo R"( @@ -1133,6 +1228,7 @@ TEST_F(CudnnFusedConvRewriterTest, TestConvScaledOutputUnsupportedUserF8) { } TEST_F(CudnnFusedConvRewriterTest, TestConvInt8ToInt8) { + MAYBE_SKIP_TEST("I8"); // max(0, clamp(conv(x, w)))); for int8_t TestClamp( // pre_hlo @@ -1167,6 +1263,7 @@ TEST_F(CudnnFusedConvRewriterTest, TestConvInt8ToInt8) { } TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToFloat) { + MAYBE_SKIP_TEST("I8"); const std::string module_str = R"( HloModule Test @@ -1198,6 +1295,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToFloat) { } TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToInt8BiasSideInput) { + MAYBE_SKIP_TEST("I8"); const std::string module_str = R"( HloModule Test @@ -1238,6 +1336,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToInt8BiasSideInput) { } TEST_F(CudnnFusedConvRewriterHloTest, TestReluAfterConvert) { + MAYBE_SKIP_TEST("I8"); const std::string module_str = R"( HloModule Test @@ -1286,6 +1385,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, TestReluAfterConvert) { } TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToFloatBiasSideInput) { + MAYBE_SKIP_TEST("I8"); const std::string module_str = R"( HloModule Test @@ -1333,6 +1433,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToFloatBiasSideInput) { // reshape(side_input * alpha). // Make sure we can pattern-match this. 
TEST_F(CudnnFusedConvRewriterHloTest, Int8SideInputWithScaleAndReshape) { + MAYBE_SKIP_TEST("I8"); const std::string module_str = R"( HloModule Test @@ -1386,6 +1487,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, Int8SideInputWithScaleAndReshape) { } TEST_F(CudnnFusedConvRewriterHloTest, FuseAlpha) { + MAYBE_SKIP_TEST("I8"); const std::string module_str = R"( HloModule Test @@ -2054,6 +2156,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseToS8IfMultipleUsers) { } TEST_F(CudnnFusedConvRewriterHloTest, RemoveConvertByFusingS32ToF32) { + MAYBE_SKIP_TEST("I8"); const std::string_view module_str = R"( HloModule Test @@ -2079,6 +2182,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, RemoveConvertByFusingS32ToF32) { } TEST_F(CudnnFusedConvRewriterHloTest, RemoveConvertByFusingS8ToF32) { + MAYBE_SKIP_TEST("I8"); const std::string_view module_str = R"( HloModule Test @@ -2104,6 +2208,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, RemoveConvertByFusingS8ToF32) { } TEST_F(CudnnFusedConvRewriterHloTest, RemoveConvertByFusingF32ToS8) { + MAYBE_SKIP_TEST("I8"); const std::string_view module_str = R"( HloModule Test @@ -2608,6 +2713,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, F32ConstantsNotLosslesslyConvertible) { } TEST_F(CudnnFusedConvRewriterHloTest, FuseReluBeforeConvert) { + MAYBE_SKIP_TEST("I8"); const std::string module_str = R"( HloModule Test @@ -2663,6 +2769,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseReluBeforeConvert) { } TEST_F(CudnnFusedConvRewriterHloTest, BiasTypeMatchesConvTypeIfFp) { + MAYBE_SKIP_TEST("F64"); const std::string module_str = R"( HloModule Test @@ -2699,6 +2806,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, BiasTypeMatchesConvTypeIfFp) { } TEST_F(CudnnFusedConvRewriterTest, TestFusedConvInt8ToInt8) { + MAYBE_SKIP_TEST("I8"); // clamp(max(0, conv(x, w)+bias)); for int8_t TestClamp( // pre_hlo @@ -2740,6 +2848,7 @@ TEST_F(CudnnFusedConvRewriterTest, TestFusedConvInt8ToInt8) { // Disabled per b/190854862 or nvbugs/3326122. TEST_F(CudnnFusedConvRewriterTest, DISABLED_TestFusedConvInt8ToFloat) { + MAYBE_SKIP_TEST("I8"); // max(0, convert(conv(int8_x), // conv(int8_w))+float_bias)); int8_t to float via bias. TestClamp( @@ -2775,6 +2884,7 @@ TEST_F(CudnnFusedConvRewriterTest, DISABLED_TestFusedConvInt8ToFloat) { TEST_F(CudnnFusedConvRewriterTest, TestFusedConvWithScaledInt8SideInputBiasInt8ToInt8) { + MAYBE_SKIP_TEST("I8"); // clamp(max(0, alpha_conv * conv(x, w) + alpha_side * // convert(int8_side_input) + bias)); for int8_t TestClamp( @@ -2826,6 +2936,7 @@ TEST_F(CudnnFusedConvRewriterTest, TEST_F(CudnnFusedConvRewriterTest, TestFusedConvWithScaledFloatSideInputBiasInt8ToInt8) { + MAYBE_SKIP_TEST("I8"); // From: // convert(clamp(max(0, alpha_conv * conv(x, w) + alpha_side * // float_side_input + bias))); To: convert(clamp(conv(int8_x, int8_w, @@ -2878,6 +2989,7 @@ TEST_F(CudnnFusedConvRewriterTest, TEST_F(CudnnFusedConvRewriterTest, TestFusedConvWithScaledInt8SideInputBiasInt8ToFloat) { + MAYBE_SKIP_TEST("I8"); // From: // clamp(max(0, alpha_conv * conv(x, w) + alpha_side * // convert(int8_side_input) + bias)); To: clamp(conv(int8_x, int8_w, @@ -2928,6 +3040,7 @@ TEST_F(CudnnFusedConvRewriterTest, } TEST_F(CudnnFusedConvRewriterTest, TestConvInt8ToInt8NoClamp) { + MAYBE_SKIP_TEST("I8"); // Check that integer convolution without clamp to int8_t is not allowed. 
// convert(custom_call(int32_x, int32_w, // cudnnConvolutionForward)) @@ -2951,6 +3064,7 @@ TEST_F(CudnnFusedConvRewriterTest, TestConvInt8ToInt8NoClamp) { } TEST_F(CudnnFusedConvRewriterTest, TestFusedConvInt8ToInt8NoClamp) { + MAYBE_SKIP_TEST("I8"); // Although bias and so on are fused with forward convolution, // it is still not allowed if the output is not clampped/converted to int8_t // max(0, alpha_conv * conv(x, w) + alpha_side * side_input + bias); for diff --git a/xla/service/gpu/runtime/convolution_thunk.cc b/xla/service/gpu/runtime/convolution_thunk.cc index 6e8158d866aaf..0036529a4acd0 100644 --- a/xla/service/gpu/runtime/convolution_thunk.cc +++ b/xla/service/gpu/runtime/convolution_thunk.cc @@ -87,6 +87,9 @@ absl::Status ConvolutionThunk::ExecuteOnStream(const ExecuteParams& params) { TF_ASSIGN_OR_RETURN(se::dnn::DataType input_type, GetDNNDataTypeFromPrimitiveType(config_.input_type)); + TF_ASSIGN_OR_RETURN(se::dnn::DataType output_type, + GetDNNDataTypeFromPrimitiveType(config_.output_type)); + TF_ASSIGN_OR_RETURN(auto dnn, se::dnn::internal::GetDnnFromStream(params.stream)); se::OwningScratchAllocator<> scratch_allocator( @@ -95,7 +98,7 @@ absl::Status ConvolutionThunk::ExecuteOnStream(const ExecuteParams& params) { std::vector profile_results; dnn->GetMIOpenConvolveAlgorithms( - kind, input_type, params.stream, config_.input_descriptor, + kind, input_type, output_type, params.stream, config_.input_descriptor, conv_params.input_buf, config_.filter_descriptor, conv_params.filter_buf, config_.output_descriptor, conv_params.output_buf, config_.conv_desc, &scratch_allocator, diff --git a/xla/stream_executor/dnn.cc b/xla/stream_executor/dnn.cc index ecd7555c94d7d..26cdd562676bd 100644 --- a/xla/stream_executor/dnn.cc +++ b/xla/stream_executor/dnn.cc @@ -287,6 +287,7 @@ DnnSupport::FusedMHABackwardRunnerFromDesc( bool DnnSupport::GetMIOpenConvolveAlgorithms( dnn::ConvolutionKind /*kind*/, dnn::DataType /*element_type*/, + dnn::DataType /*output_type*/, Stream* /*stream*/, const dnn::BatchDescriptor& /*input_descriptor*/, DeviceMemoryBase input_data, const dnn::FilterDescriptor& /*filter_descriptor*/, diff --git a/xla/stream_executor/dnn.h b/xla/stream_executor/dnn.h index 41fb3e34c8fc0..c3f8e4fab5bb3 100644 --- a/xla/stream_executor/dnn.h +++ b/xla/stream_executor/dnn.h @@ -1767,7 +1767,8 @@ class DnnSupport { dnn::FMHAMaskKind mask_type); virtual bool GetMIOpenConvolveAlgorithms( - ConvolutionKind kind, DataType element_type, Stream* stream, + ConvolutionKind kind, DataType element_type, DataType output_type, + Stream* stream, const BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, const FilterDescriptor& filter_descriptor, DeviceMemoryBase filter_data, const BatchDescriptor& output_descriptor, DeviceMemoryBase output_data, diff --git a/xla/stream_executor/rocm/rocm_dnn.cc b/xla/stream_executor/rocm/rocm_dnn.cc index fd3bcafa50139..04619cfdb9e1f 100644 --- a/xla/stream_executor/rocm/rocm_dnn.cc +++ b/xla/stream_executor/rocm/rocm_dnn.cc @@ -42,6 +42,7 @@ limitations under the License. 
#include "xla/stream_executor/rocm/rocm_platform_id.h" #include "xla/stream_executor/scratch_allocator.h" #include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_interface.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tsl/util/determinism.h" #include "xla/tsl/util/env_var.h" @@ -112,6 +113,8 @@ string ToString(miopenStatus_t status) { return "miopenStatusNotImplemented"; case miopenStatusUnknownError: return "miopenStatusUnknownError"; + case miopenStatusUnsupportedOp: + return "miopenStatusUnsupportedOp"; default: return absl::StrCat("(status), ">"); @@ -377,6 +380,8 @@ namespace wrap { __macro(miopenConvolutionForwardGetSolutionWorkspaceSize) \ __macro(miopenConvolutionForwardCompileSolution) \ __macro(miopenConvolutionForwardImmediate) \ + __macro(miopenConvolutionForwardBias) \ + __macro(miopenConvolutionBiasActivationForward) \ __macro(miopenConvolutionBackwardDataGetSolutionCount) \ __macro(miopenConvolutionBackwardDataGetSolution) \ __macro(miopenConvolutionBackwardDataGetSolutionWorkspaceSize) \ @@ -495,11 +500,13 @@ namespace wrap { __macro(miopenExecuteFusionPlan) \ __macro(miopenDestroyOperatorArgs) \ __macro(miopenDestroyFusionPlan) \ + __macro(miopenConvolutionBiasActivationForward) \ __macro(miopenConvolutionForwardGetSolutionCount) \ __macro(miopenConvolutionForwardGetSolution) \ __macro(miopenConvolutionForwardGetSolutionWorkspaceSize) \ __macro(miopenConvolutionForwardCompileSolution) \ __macro(miopenConvolutionForwardImmediate) \ + __macro(miopenConvolutionForwardBias) \ __macro(miopenConvolutionBackwardDataGetSolutionCount) \ __macro(miopenConvolutionBackwardDataGetSolution) \ __macro(miopenConvolutionBackwardDataGetSolutionWorkspaceSize) \ @@ -565,7 +572,7 @@ uint64_t GetHashValue(miopenConvolutionDescriptor_t conv_desc) { uint64_t hash_value = tsl::hash()(c_mode); auto hash64Combine = [&hash_value](int element) { - tsl::Hash64Combine(hash_value, tsl::hash()(element)); + hash_value = tsl::Hash64Combine(hash_value, tsl::hash()(element)); }; std::for_each(pad.begin(), pad.end(), hash64Combine); std::for_each(stride.begin(), stride.end(), hash64Combine); @@ -597,9 +604,11 @@ class CachedFusionPlans { auto it = cached_plans.find(hash); if (it != cached_plans.end()) { + VLOG(2) << "Found a cached plan for " << hash; *fusion_plan = it->second; found_cached_plan = true; } else { + VLOG(2) << "Creating a new plan for " << hash; auto status = wrap::miopenCreateFusionPlan(fusion_plan, fusion_direction, input_descriptor); if (status != miopenStatusSuccess) { @@ -675,6 +684,7 @@ dnn::ProfileResult GetProfileResultFromConvAlgoPerf( int64_t algo_id; switch (kind) { case dnn::ConvolutionKind::FORWARD: + case dnn::ConvolutionKind::FORWARD_BIAS_ACTIVATION: algo_id = algorithm.fwd_algo; break; case dnn::ConvolutionKind::BACKWARD_DATA: @@ -798,77 +808,107 @@ absl::StatusOr MIOpenSupport::GetVersion() { return stream_executor::dnn::VersionInfo(1, 3, 0); } -// Turns a BatchDescriptor structure into a miopen tensor handle within a scope. 
-class ScopedTensorDescriptor { - public: - ScopedTensorDescriptor(const BatchDescriptor& batch_descriptor, - miopenDataType_t elem_type) - : handle_(nullptr) { - auto status = wrap::miopenCreateTensorDescriptor(&handle_); - if (status != miopenStatusSuccess) { - LOG(FATAL) << "could not create miopen tensor descriptor: " - << ToString(status); - } +template +miopenStatus_t miDestroyObject(T obj) { return miopenStatusSuccess; } - switch (batch_descriptor.layout()) { - case dnn::DataLayout::kBatchYXDepth: - case dnn::DataLayout::kBatchDepthYX: { - const int nd = batch_descriptor.ndims() + 2; +template<> miopenStatus_t miDestroyObject(miopenTensorDescriptor_t obj) { + return wrap::miopenDestroyTensorDescriptor(obj); +} - // MIOpen requires the strides and dims to be ordered as BDYX. - std::vector strides64 = - batch_descriptor.full_strides(dnn::DataLayout::kBatchDepthYX); - std::vector dims64 = - batch_descriptor.full_dims(dnn::DataLayout::kBatchDepthYX); +template<> miopenStatus_t miDestroyObject(miopenConvolutionDescriptor_t obj) { + return wrap::miopenDestroyConvolutionDescriptor(obj); +} - // MIOpen requires arrays of ints. - std::vector strides(nd); - std::vector dims(nd); - std::transform(strides64.cbegin(), strides64.cend(), strides.begin(), - &CheckedNarrowing); - std::transform(dims64.cbegin(), dims64.cend(), dims.begin(), - &CheckedNarrowing); - status = wrap::miopenSetTensorDescriptor(handle_, elem_type, nd, - dims.data(), strides.data()); +template<> miopenStatus_t miDestroyObject(miopenPoolingDescriptor_t obj) { + return wrap::miopenDestroyPoolingDescriptor(obj); +} - if (status != miopenStatusSuccess) { - LOG(FATAL) << "could not convert BatchDescriptor " - << batch_descriptor.ToString() - << " to miopen tensor descriptor: " << ToString(status); - } - } break; - default: - LOG(FATAL) << "Unsupported tensor format " - << DataLayoutString(batch_descriptor.layout()); - break; - } +template<> miopenStatus_t miDestroyObject(miopenLRNDescriptor_t obj) { + return wrap::miopenDestroyLRNDescriptor(obj); +} + +template +struct ScopedDescriptor { + ScopedDescriptor() : handle_(nullptr) {} + + ScopedDescriptor(ScopedDescriptor&& other) { + handle_ = other.handle_; + other.handle_ = nullptr; } - ~ScopedTensorDescriptor() { - auto status = wrap::miopenDestroyTensorDescriptor(handle_); + ~ScopedDescriptor() { + if(handle_ != nullptr) + return; + + auto status = miDestroyObject(handle_);//wrap::miopenDestroyTensorDescriptor(handle_); if (status != miopenStatusSuccess) { LOG(ERROR) << "could not destroy miopen tensor descriptor: " << ToString(status); } } - miopenTensorDescriptor_t handle() const { return handle_; } + T handle() const { return handle_; } - private: - miopenTensorDescriptor_t handle_; // Owned. + T handle_; // Owned. - ScopedTensorDescriptor(const ScopedTensorDescriptor&) = delete; - void operator=(const ScopedTensorDescriptor&) = delete; + ScopedDescriptor(const ScopedDescriptor&) = delete; + void operator=(const ScopedDescriptor&) = delete; }; -// Turns a FilterDescriptor structure into a miopen filter handle within a -// scope. 
-class ScopedFilterDescriptor { - public: - ScopedFilterDescriptor(const FilterDescriptor& filter_descriptor, - miopenDataType_t elem_type) - : handle_(nullptr) { - auto status = wrap::miopenCreateTensorDescriptor(&handle_); +using ScopedTensorDescriptor = ScopedDescriptor; +using ScopedFilterDescriptor = ScopedDescriptor; +using ScopedConvolutionDescriptor = ScopedDescriptor; +using ScopedPoolingDescriptor = ScopedDescriptor; +using ScopedNormalizeDescriptor = ScopedDescriptor; + +absl::StatusOr scope(const BatchDescriptor& batch_descriptor, + miopenDataType_t data_type) { + ScopedTensorDescriptor obj; + auto status = wrap::miopenCreateTensorDescriptor(&obj.handle_); + if (status != miopenStatusSuccess) { + return absl::InternalError("could not create miopen tensor descriptor: " + + ToString(status)); + } + + switch (batch_descriptor.layout()) { + case dnn::DataLayout::kBatchYXDepth: + case dnn::DataLayout::kBatchDepthYX: { + const int nd = batch_descriptor.ndims() + 2; + + // MIOpen requires the strides and dims to be ordered as BDYX. + std::vector strides64 = + batch_descriptor.full_strides(dnn::DataLayout::kBatchDepthYX); + std::vector dims64 = + batch_descriptor.full_dims(dnn::DataLayout::kBatchDepthYX); + + // MIOpen requires arrays of ints. + std::vector strides(nd); + std::vector dims(nd); + std::transform(strides64.cbegin(), strides64.cend(), strides.begin(), + &CheckedNarrowing); + std::transform(dims64.cbegin(), dims64.cend(), dims.begin(), + &CheckedNarrowing); + status = wrap::miopenSetTensorDescriptor(obj.handle_, data_type, nd, + dims.data(), strides.data()); + + if (status != miopenStatusSuccess) { + return absl::InternalError("could not convert BatchDescriptor " + + batch_descriptor.ToString() + + " to miopen tensor descriptor: " + ToString(status)); + } + } break; + default: + return absl::InternalError("Unsupported tensor format " + + DataLayoutString(batch_descriptor.layout())); + break; + } + return obj; +} + +absl::StatusOr scope(const FilterDescriptor& filter_descriptor, + miopenDataType_t data_type) { + ScopedFilterDescriptor obj; + auto status = wrap::miopenCreateTensorDescriptor(&obj.handle_); if (status != miopenStatusSuccess) { LOG(FATAL) << "could not create miopen filter descriptor: " << ToString(status); @@ -925,7 +965,7 @@ class ScopedFilterDescriptor { &CheckedNarrowing); absl::c_transform(dims64, std::back_inserter(dims), &CheckedNarrowing); - status = wrap::miopenSetTensorDescriptor(handle_, elem_type, nd, + status = wrap::miopenSetTensorDescriptor(obj.handle_, data_type, nd, dims.data(), strides.data()); if (status != miopenStatusSuccess) { @@ -939,35 +979,12 @@ class ScopedFilterDescriptor { << FilterLayoutString(filter_descriptor.layout()); break; } - } - - ~ScopedFilterDescriptor() { - auto status = wrap::miopenDestroyTensorDescriptor(handle_); - if (status != miopenStatusSuccess) { - LOG(ERROR) << "could not destroy miopen filter descriptor: " - << ToString(status); - } - } - - miopenTensorDescriptor_t handle() const { return handle_; } - - private: - // miopen filter descriptor this object creates. Owned. - miopenTensorDescriptor_t handle_; - - ScopedFilterDescriptor(const ScopedFilterDescriptor&) = delete; - void operator=(const ScopedFilterDescriptor&) = delete; -}; + return obj; +} -// Turns a ConvolutionDescriptor structure into a miopen convolution handle -// within a scope. 
-class ScopedConvolutionDescriptor { - public: - ScopedConvolutionDescriptor( - const ConvolutionDescriptor& convolution_descriptor, - miopenDataType_t data_type) - : handle_(nullptr) { - auto status = wrap::miopenCreateConvolutionDescriptor(&handle_); +absl::StatusOr scope(const ConvolutionDescriptor& convolution_descriptor) { + ScopedConvolutionDescriptor obj; + auto status = wrap::miopenCreateConvolutionDescriptor(&obj.handle_); if (status != miopenStatusSuccess) { LOG(FATAL) << "could not create miopen convolution descriptor: " << ToString(status); @@ -993,7 +1010,7 @@ class ScopedConvolutionDescriptor { &CheckedNarrowing); status = wrap::miopenInitConvolutionNdDescriptor( - handle_, convolution_descriptor.ndims(), padding.data(), strides.data(), + obj.handle_, convolution_descriptor.ndims(), padding.data(), strides.data(), upscale.data(), miopenConvolution); if (status != miopenStatusSuccess) { LOG(FATAL) << "could not set miopen convolution descriptor: " @@ -1003,7 +1020,7 @@ class ScopedConvolutionDescriptor { VLOG(2) << "Requesting grouped convolution: " << convolution_descriptor.group_count(); status = wrap::miopenSetConvolutionGroupCount( - handle_, convolution_descriptor.group_count()); + obj.handle_, convolution_descriptor.group_count()); if (status != miopenStatusSuccess) { LOG(FATAL) << "could not set miopen convolution group count: " << ToString(status); @@ -1012,38 +1029,19 @@ class ScopedConvolutionDescriptor { #if (TF_ROCM_VERSION >= 50300) if (RequireMIOpenDeterminism()) { status = wrap::miopenSetConvolutionAttribute( - handle_, MIOPEN_CONVOLUTION_ATTRIB_DETERMINISTIC, 1); + obj.handle_, MIOPEN_CONVOLUTION_ATTRIB_DETERMINISTIC, 1); if (status != miopenStatusSuccess) { LOG(FATAL) << "could not set miopen convolution attribute: " << ToString(status); } } #endif - } - ~ScopedConvolutionDescriptor() { - auto status = wrap::miopenDestroyConvolutionDescriptor(handle_); - if (status != miopenStatusSuccess) { - LOG(ERROR) << "could not destroy miopen convolution descriptor: " - << ToString(status); - } - } - - miopenConvolutionDescriptor_t handle() const { return handle_; } - - private: - miopenConvolutionDescriptor_t handle_; // Owned. - - ScopedConvolutionDescriptor(const ScopedConvolutionDescriptor&) = delete; - void operator=(const ScopedConvolutionDescriptor&) = delete; -}; + return obj; +} -// Turns a PoolingDescriptor structure into a miopen pooling descriptor handle -// within a scope. -class ScopedPoolingDescriptor { - public: - ScopedPoolingDescriptor(const PoolingDescriptor& pooling_descriptor) - : handle_(nullptr) { - auto status = wrap::miopenCreatePoolingDescriptor(&handle_); +absl::StatusOr scope(const PoolingDescriptor& pooling_descriptor) { + ScopedPoolingDescriptor obj; + auto status = wrap::miopenCreatePoolingDescriptor(&obj.handle_); if (status != miopenStatusSuccess) { LOG(FATAL) << "could not create miopen pooling descriptor: " << ToString(status); @@ -1065,7 +1063,7 @@ class ScopedPoolingDescriptor { &CheckedNarrowing); status = wrap::miopenSetNdPoolingDescriptor( - handle_, + obj.handle_, (pooling_descriptor.mode() == dnn::PoolingMode::kMaximum ? miopenPoolingMax : miopenPoolingAverage), @@ -1075,36 +1073,19 @@ class ScopedPoolingDescriptor { // API assumes all input indexes to be the same type. 
Since a tensor // descriptor can only use int32 type, the index type here need to be // aligned with the tensor index type of the (input) tensor descritptor - status = wrap::miopenSetPoolingIndexType(handle_, miopenIndexUint32); + status = wrap::miopenSetPoolingIndexType(obj.handle_, miopenIndexUint32); if (status != miopenStatusSuccess) { LOG(FATAL) << "could not set miopen pooling descriptor: " << ToString(status); } - } - ~ScopedPoolingDescriptor() { - auto status = wrap::miopenDestroyPoolingDescriptor(handle_); - if (status != miopenStatusSuccess) { - LOG(ERROR) << "could not destroy miopen pooling descriptor: " - << ToString(status); - } - } - - miopenPoolingDescriptor_t handle() const { return handle_; } - - private: - miopenPoolingDescriptor_t handle_; // Owned. + return obj; +} - ScopedPoolingDescriptor(const ScopedPoolingDescriptor&) = delete; - void operator=(const ScopedPoolingDescriptor&) = delete; -}; -// Turns a NormalizeDescriptor structure into a miopen LRN descriptor handle. -class ScopedNormalizeDescriptor { - public: - ScopedNormalizeDescriptor(const NormalizeDescriptor& normalize_descriptor) - : handle_(nullptr) { - auto status = wrap::miopenCreateLRNDescriptor(&handle_); +absl::StatusOr scope(const NormalizeDescriptor& normalize_descriptor) { + ScopedNormalizeDescriptor obj; + auto status = wrap::miopenCreateLRNDescriptor(&obj.handle_); if (status != miopenStatusSuccess) { LOG(FATAL) << "could not create miopen LRN descriptor: " << ToString(status); @@ -1130,93 +1111,84 @@ class ScopedNormalizeDescriptor { double lrn_beta = normalize_descriptor.beta(); double lrn_k = normalize_descriptor.bias(); - status = wrap::miopenSetLRNDescriptor(handle_, miopenLRNCrossChannel, lrn_N, + status = wrap::miopenSetLRNDescriptor(obj.handle_, miopenLRNCrossChannel, lrn_N, lrn_alpha, lrn_beta, lrn_k); if (status != miopenStatusSuccess) { LOG(FATAL) << "could not set miopen LRN descriptor: " << ToString(status); } - } - - ~ScopedNormalizeDescriptor() { - auto status = wrap::miopenDestroyLRNDescriptor(handle_); - if (status != miopenStatusSuccess) { - LOG(ERROR) << "could not destroy miopen LRN descriptor: " - << ToString(status); - } - } - - miopenLRNDescriptor_t handle() const { return handle_; } - - private: - miopenLRNDescriptor_t handle_; // Owned. 
- - ScopedNormalizeDescriptor(const ScopedNormalizeDescriptor&) = delete; - void operator=(const ScopedNormalizeDescriptor&) = delete; -}; + return obj; +} // Turns a activation mode into a miopen activation mode descriptor with a scope // around it -class ScopedActivationDescriptor { - public: - ScopedActivationDescriptor(dnn::ActivationMode activation_mode) - : handle_(nullptr), - miopen_activation_mode_(miopenActivationPASTHRU), - alpha_(0.0), - beta_(0.0), - gamma_(0.0) { - auto status = wrap::miopenCreateActivationDescriptor(&handle_); +struct ScopedActivationDescriptor: ScopedDescriptor +{ + static absl::StatusOr + Create(dnn::ActivationMode activation_mode, double alpha=0.0) { + ScopedActivationDescriptor obj; + obj.alpha_ = alpha; + auto status = wrap::miopenCreateActivationDescriptor(&obj.handle_); if (status != miopenStatusSuccess) { - LOG(FATAL) << "call to miopenCreateActivationDescriptor failed: " - << ToString(status); + return absl::InternalError("call to miopenCreateActivationDescriptor failed: " + + ToString(status)); } else { switch (activation_mode) { case dnn::ActivationMode::kNone: - miopen_activation_mode_ = miopenActivationPASTHRU; + obj.miopen_activation_mode_ = miopenActivationPASTHRU; break; case dnn::ActivationMode::kSigmoid: - miopen_activation_mode_ = miopenActivationLOGISTIC; + obj.miopen_activation_mode_ = miopenActivationLOGISTIC; break; case dnn::ActivationMode::kRelu: - miopen_activation_mode_ = miopenActivationRELU; + case dnn::ActivationMode::kReluX: + obj.miopen_activation_mode_ = miopenActivationRELU; break; case dnn::ActivationMode::kRelu6: - miopen_activation_mode_ = miopenActivationRELU; - alpha_ = 6.0; + obj.miopen_activation_mode_ = miopenActivationRELU; + obj.alpha_ = 6.0; break; case dnn::ActivationMode::kTanh: - miopen_activation_mode_ = miopenActivationTANH; + obj.miopen_activation_mode_ = miopenActivationTANH; + break; + + case dnn::ActivationMode::kElu: + obj.miopen_activation_mode_ = miopenActivationELU; + break; + + case dnn::ActivationMode::kLeakyRelu: + obj.miopen_activation_mode_ = miopenActivationLEAKYRELU; break; + // Check with MIOpen re: support: kBandPass, kGeluExact default: - LOG(FATAL) << "Activation mode (" + VLOG(1) << "Activation mode (" << dnn::ActivationModeString(activation_mode) << ") not yet implemented"; - break; + return absl::InternalError("Activation not implemented"); } status = wrap::miopenSetActivationDescriptor( - handle_, miopen_activation_mode_, alpha_, beta_, gamma_); + obj.handle_, obj.miopen_activation_mode_, obj.alpha_, obj.beta_, obj.gamma_); if (status != miopenStatusSuccess) { - LOG(FATAL) << "call to miopenSetActivationDescriptor failed: " - << ToString(status); + return absl::InternalError("call to miopenSetActivationDescriptor failed: " + + ToString(status)); } } + return obj; } - - ~ScopedActivationDescriptor() { - auto status = wrap::miopenDestroyActivationDescriptor(handle_); - if (status != miopenStatusSuccess) { - LOG(FATAL) << "call to miopenDestroyActivationDescriptor failed: " - << ToString(status); - } + ScopedActivationDescriptor(ScopedActivationDescriptor&& other) + : ScopedDescriptor(std::move(other)) + { + miopen_activation_mode_ = other.miopen_activation_mode_; + alpha_ = other.alpha_; + beta_ = other.beta_; + gamma_ = other.gamma_; } - miopenActivationDescriptor_t handle() const { return handle_; } - uint64_t GetHashValue() { uint64_t hash_value = tsl::hash()(miopen_activation_mode_); hash_value = tsl::Hash64Combine(hash_value, tsl::hash()(alpha_)); @@ -1226,13 +1198,12 @@ class 
ScopedActivationDescriptor { return hash_value; } - private: - miopenActivationDescriptor_t handle_; // Owned. - - ScopedActivationDescriptor(const ScopedActivationDescriptor&) = delete; - void operator=(const ScopedActivationDescriptor&) = delete; + ScopedActivationDescriptor() + : miopen_activation_mode_(miopenActivationPASTHRU), + alpha_(0.0), + beta_(0.0), + gamma_(0.0) {} - public: // caching these values here to avoid calling miopenGetActivationDescriptor // to do the same. miopenGetActivationDescriptor gets called twice during each // call to execute a fusion plan (that involves the activation op)...once call @@ -1262,6 +1233,8 @@ class ScopedFusionPlanBase { } virtual ~ScopedFusionPlanBase() { + if(fusion_args_ == nullptr) + return; auto status = wrap::miopenDestroyOperatorArgs(fusion_args_); if (status != miopenStatusSuccess) { LOG(FATAL) << "call to miopenDestroyoperatorArgs failed: " @@ -1286,7 +1259,6 @@ class ScopedFusionPlanBase { bool CompilationSucceeded() { return fusion_plan_compiled_; } - protected: miopenStatus_t SetConvolutionArgs(const int op_idx, const float* alpha, const float* beta, const void* data) { miopenFusionOpDescriptor_t conv_op; @@ -1436,12 +1408,20 @@ class ScopedFusionPlanBase { } return status; } - +public: miopenHandle_t miopen_handle_; miopenFusionPlanDescriptor_t fusion_plan_; miopenOperatorArgs_t fusion_args_; // Owned. bool fusion_plan_compiled_; + ScopedFusionPlanBase(ScopedFusionPlanBase&& other) { + miopen_handle_ = other.miopen_handle_; + fusion_plan_ = other.fusion_plan_; + fusion_args_ = other.fusion_args_; + other.fusion_args_ = nullptr; + fusion_plan_compiled_ = other.fusion_plan_compiled_; + } + ScopedFusionPlanBase(const ScopedFusionPlanBase&) = delete; void operator=(const ScopedFusionPlanBase&) = delete; }; @@ -1449,87 +1429,101 @@ class ScopedFusionPlanBase { // class to represent the Convolution+Bias+Activation fusion plan class ScopedFusionPlanConvolutionBiasActivation : public ScopedFusionPlanBase { public: - ScopedFusionPlanConvolutionBiasActivation( + ScopedFusionPlanConvolutionBiasActivation(miopenHandle_t miopen_handle, + miopenTensorDescriptor_t input_descriptor) + : ScopedFusionPlanBase(miopen_handle, + miopenVerticalFusion, + input_descriptor) {} + + ScopedFusionPlanConvolutionBiasActivation(ScopedFusionPlanConvolutionBiasActivation&& other) + : ScopedFusionPlanBase(std::move(other)) { + conv_op = other.conv_op; + bias_op = other.bias_op; + actv_op = other.actv_op; + } + + static absl::StatusOr Create( miopenHandle_t miopen_handle, miopenTensorDescriptor_t input_descriptor, miopenTensorDescriptor_t filter_descriptor, miopenConvolutionDescriptor_t conv_descriptor, miopenTensorDescriptor_t bias_descriptor, - ScopedActivationDescriptor& activation_descriptor) - : ScopedFusionPlanBase(miopen_handle, miopenVerticalFusion, - input_descriptor) { + ScopedActivationDescriptor& act_descriptor) { + ScopedFusionPlanConvolutionBiasActivation obj(miopen_handle, input_descriptor); + + VLOG(2) << "Fusion Plan compile begin"; + uint64_t hash = GetFusionOpHashValue( miopen_handle, input_descriptor, filter_descriptor, conv_descriptor, - bias_descriptor, activation_descriptor); + bias_descriptor, act_descriptor); bool is_compiled = CachedFusionPlans::FindOrCreate( - hash, &fusion_plan_, miopenVerticalFusion, input_descriptor); + hash, &obj.fusion_plan_, miopenVerticalFusion, input_descriptor); + if(is_compiled) + VLOG(2) << "Cache hit"; if (!is_compiled) { - miopenFusionOpDescriptor_t conv_op; auto status = wrap::miopenCreateOpConvForward( - 
fusion_plan_, &conv_op, conv_descriptor, filter_descriptor); - if (status != miopenStatusSuccess) { - LOG(FATAL) << "call to miopenCreateOpConvForward failed: " - << ToString(status); - } + obj.fusion_plan_, &obj.conv_op, conv_descriptor, filter_descriptor); + if (status != miopenStatusSuccess) + return absl::InternalError("miopenCreateOpConvForward failed: " + + ToString(status)); - miopenFusionOpDescriptor_t bias_op; - status = wrap::miopenCreateOpBiasForward(fusion_plan_, &bias_op, + status = wrap::miopenCreateOpBiasForward(obj.fusion_plan_, &obj.bias_op, bias_descriptor); - if (status != miopenStatusSuccess) { - LOG(FATAL) << "call to miopenCreateOpBiasForward failed: " - << ToString(status); - } - - miopenFusionOpDescriptor_t actv_op; - status = wrap::miopenCreateOpActivationForward( - fusion_plan_, &actv_op, - activation_descriptor.miopen_activation_mode_); - if (status != miopenStatusSuccess) { - LOG(FATAL) << "call to miopenCreateOpActivationForward failed: " - << ToString(status); + if (status != miopenStatusSuccess) + return absl::InternalError("miopenCreateOpBiasForward failed: " + + ToString(status)); + + if(act_descriptor.miopen_activation_mode_ != miopenActivationPASTHRU) { + status = wrap::miopenCreateOpActivationForward( + obj.fusion_plan_, &obj.actv_op, + act_descriptor.miopen_activation_mode_); + if (status != miopenStatusSuccess) + return absl::InternalError( + "miopenCreateOpActivationForward failed: " + ToString(status)); } - status = wrap::miopenCompileFusionPlan(miopen_handle_, fusion_plan_); + status = wrap::miopenCompileFusionPlan(miopen_handle, obj.fusion_plan_); if (status != miopenStatusSuccess) { VLOG(2) << "call to miopenCompileFusionPlan (CBA) failed: " << ToString(status); CachedFusionPlans::MarkFusionPlanUnsupported(hash); } else { - VLOG(2) << "Fusion Plan compile succedded (CBA) "; - fusion_plan_compiled_ = true; + VLOG(2) << "Fusion Plan compile succeeded (CBA) "; + obj.fusion_plan_compiled_ = true; } } else { // fusion plan was already compiled...check whether it failed to compile - fusion_plan_compiled_ = !CachedFusionPlans::IsUnsupportedFusionPlan(hash); + obj.fusion_plan_compiled_ = !CachedFusionPlans::IsUnsupportedFusionPlan(hash); } + return obj; } miopenStatus_t SetConvolutionArgs(const void* filter_data) { - float alpha = 1.0; - float beta = 0.0; + static const float alpha = 1.0; + static const float beta = 0.0; return ScopedFusionPlanBase::SetConvolutionArgs(k_conv_op_idx, &alpha, &beta, filter_data); } miopenStatus_t SetBiasArgs(const void* bias_data) { - float alpha = 1.0; - float beta = 0.0; + static const float alpha = 1.0; + static const float beta = 0.0; return ScopedFusionPlanBase::SetBiasArgs(k_bias_op_idx, &alpha, &beta, bias_data); } miopenStatus_t SetActivationForwardArgs( ScopedActivationDescriptor& activation_descriptor) { - float alpha = 1.0; - float beta = 0.0; + static const float alpha = 1.0; + static const float beta = 0.0; return ScopedFusionPlanBase::SetActivationForwardArgs( k_actv_op_idx, &alpha, &beta, activation_descriptor.alpha_, activation_descriptor.beta_, activation_descriptor.gamma_); } - uint64_t GetFusionOpHashValue( + static uint64_t GetFusionOpHashValue( miopenHandle_t miopen_handle, miopenTensorDescriptor_t input_descriptor, miopenTensorDescriptor_t filter_descriptor, miopenConvolutionDescriptor_t conv_descriptor, @@ -1549,6 +1543,11 @@ class ScopedFusionPlanConvolutionBiasActivation : public ScopedFusionPlanBase { tsl::Hash64Combine(hash_value, activation_descriptor.GetHashValue()); return hash_value; } 
+public: + miopenFusionOpDescriptor_t conv_op; + miopenFusionOpDescriptor_t bias_op; + miopenFusionOpDescriptor_t actv_op; + private: const int k_conv_op_idx = 0; @@ -1718,8 +1717,8 @@ class ScopedFusionPlanBatchNormActivationForward : public ScopedFusionPlanBase { void* batch_mean, void* batch_var, void* saved_mean, void* saved_var, double epsilon) { - float alpha = 1.0; - float beta = 0.0; + static const float alpha = 1.0; + static const float beta = 0.0; return ScopedFusionPlanBase::SetBatchNormForwardArgs( k_batchnorm_op_idx, &alpha, &beta, scale, offset, batch_mean, batch_var, saved_mean, saved_var, /*exponential_average_factor=*/1.0, epsilon); @@ -1727,8 +1726,8 @@ class ScopedFusionPlanBatchNormActivationForward : public ScopedFusionPlanBase { miopenStatus_t SetActivationForwardArgs( ScopedActivationDescriptor& activation_descriptor) { - float alpha = 1.0; - float beta = 0.0; + static const float alpha = 1.0; + static const float beta = 0.0; return ScopedFusionPlanBase::SetActivationForwardArgs( k_actv_op_idx, &alpha, &beta, activation_descriptor.alpha_, @@ -1867,6 +1866,24 @@ class ScopedFusionPlanBatchNormActivationBackward }; namespace { + +const char* getTypeName(dnn::DataType data_type) { + switch (data_type) { + case dnn::DataType::kBF16: + return "BF16"; + case dnn::DataType::kFloat: + return "F32"; + case dnn::DataType::kHalf: + return "F16"; + case dnn::DataType::kInt8: + return "I8"; + case dnn::DataType::kDouble: + return "F64"; + default: + return "Unknown"; + } +} + miopenDataType_t ToMIOpenDataType( dnn::DataType data_type, dnn::DataLayout data_layout = dnn::DataLayout::kBatchDepthYX) { @@ -3245,22 +3262,23 @@ absl::Status MIOpenSupport::DoPrepareForConvolution( class RocmConvRunner : public dnn::ConvRunner { public: RocmConvRunner(GpuExecutor* parent, MIOpenAccess* miopen, int64_t algo_id, - size_t workspace_size, dnn::ConvolutionKind kind, - dnn::DataType input_type, bool use_immediate_mode, - BatchDescriptor input_descriptor, - BatchDescriptor output_descriptor, - FilterDescriptor filter_descriptor, - ConvolutionDescriptor conv_descriptor) + size_t workspace_size, dnn::ConvolutionKind kind, + dnn::DataType input_type, dnn::DataType output_type, + bool use_immediate_mode, + ScopedTensorDescriptor& scoped_input_desc, + ScopedTensorDescriptor& scoped_output_desc, + ScopedFilterDescriptor& scoped_filter_desc, + ScopedConvolutionDescriptor& scoped_conv_desc) : parent_(parent), miopen_(miopen), algo_id_(algo_id), workspace_size_(workspace_size), kind_(kind), use_immediate_mode_(use_immediate_mode), - input_desc_{input_descriptor, ToMIOpenDataType(input_type)}, - output_desc_{output_descriptor, ToMIOpenDataType(input_type)}, - filter_desc_{filter_descriptor, ToMIOpenDataType(input_type)}, - conv_desc_{conv_descriptor, ToMIOpenDataType(input_type)} { + input_desc_(std::move(scoped_input_desc)), + output_desc_(std::move(scoped_output_desc)), + filter_desc_(std::move(scoped_filter_desc)), + conv_desc_(std::move(scoped_conv_desc)) { bool is_backprop = ((kind == dnn::ConvolutionKind::BACKWARD_DATA) || (kind == dnn::ConvolutionKind::BACKWARD_FILTER)); // #if TF_ROCM_VERSION >= 50000 @@ -3439,13 +3457,11 @@ absl::Status MIOpenSupport::GetConvolveRunners( } std::vector profile_results; - if (!GetMIOpenConvolveAlgorithms( - kind, input_type, stream, input_descriptor, input_data, + if(!GetMIOpenConvolveAlgorithms( + kind, input_type, output_type, stream, input_descriptor, input_data, filter_descriptor, filter_data, output_descriptor, output_data, - convolution_descriptor, 
scratch_allocator, &profile_results)) { - return absl::UnknownError( - "GetConvolveRunners: GetMIOpenConvolveAlgorithms failed"); - } + convolution_descriptor, scratch_allocator, &profile_results)) + return absl::InternalError("GetMIOpenConvolveAlgorithms failure"); for (const auto& profile_result : profile_results) { TF_ASSIGN_OR_RETURN( @@ -3467,27 +3483,24 @@ MIOpenSupport::ConvolveRunnerFromDesc( const dnn::FilterDescriptor& filter_descriptor, const dnn::BatchDescriptor& output_descriptor, const dnn::ConvolutionDescriptor& convolution_descriptor) { - if (input_type != output_type) { - return absl::UnimplementedError( - absl::StrFormat("MIOpen backend does not support different input and " - "output types: %d != %d", - input_type, output_type)); - } - auto workspace_size = algorithm_desc.workspace_size(); - if (!workspace_size) { - return absl::InvalidArgumentError( - "MIOpenSupport::ConvolveRunnerFromDesc requires " - "AlgorithmProto.workspace_size, but it was missing."); - } + TF_ASSIGN_OR_RETURN(auto scoped_input_desc, scope(input_descriptor, ToMIOpenDataType(input_type))); + TF_ASSIGN_OR_RETURN(auto scoped_output_desc, scope(output_descriptor, ToMIOpenDataType(output_type))); + TF_ASSIGN_OR_RETURN(auto scoped_filter_desc, scope(filter_descriptor, ToMIOpenDataType(input_type))); + TF_ASSIGN_OR_RETURN(auto scoped_conv_desc, scope(convolution_descriptor)); + return {std::make_unique( parent_, miopen_.get(), algorithm_desc.algo_id(), *workspace_size, kind, - input_type, use_immediate_mode_, input_descriptor, output_descriptor, - filter_descriptor, convolution_descriptor)}; + input_type, output_type, use_immediate_mode_, + scoped_input_desc, + scoped_output_desc, + scoped_filter_desc, + scoped_conv_desc)}; } bool MIOpenSupport::GetMIOpenConvolveAlgorithms( - dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream, + dnn::ConvolutionKind kind, dnn::DataType input_type, dnn::DataType output_type, + Stream* stream, const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, const dnn::FilterDescriptor& filter_descriptor, DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor, @@ -3497,19 +3510,20 @@ bool MIOpenSupport::GetMIOpenConvolveAlgorithms( std::vector* out_algorithms) { return use_immediate_mode_ ? 
GetMIOpenConvolveAlgorithmsImmediateMode( - kind, element_type, stream, input_descriptor, input_data, + kind, input_type, output_type, stream, input_descriptor, input_data, filter_descriptor, filter_data, output_descriptor, output_data, convolution_descriptor, scratch_allocator, - out_algorithms) + out_algorithms).ok() : GetMIOpenConvolveAlgorithmsFindMode( - kind, element_type, stream, input_descriptor, input_data, + kind, input_type, output_type, stream, input_descriptor, input_data, filter_descriptor, filter_data, output_descriptor, output_data, convolution_descriptor, scratch_allocator, - out_algorithms); + out_algorithms).ok(); } -bool MIOpenSupport::GetMIOpenConvolveAlgorithmsImmediateMode( - dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream, +absl::Status MIOpenSupport::GetMIOpenConvolveAlgorithmsImmediateMode( + dnn::ConvolutionKind kind, dnn::DataType input_type, dnn::DataType output_type, + Stream* stream, const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, const dnn::FilterDescriptor& filter_descriptor, DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor, @@ -3519,14 +3533,13 @@ bool MIOpenSupport::GetMIOpenConvolveAlgorithmsImmediateMode( std::vector* out_algorithms) { auto miopen = miopen_->GetHandle(parent_, stream); - ScopedTensorDescriptor input_nd{input_descriptor, - ToMIOpenDataType(element_type)}; - ScopedTensorDescriptor output_nd{output_descriptor, - ToMIOpenDataType(element_type)}; - ScopedFilterDescriptor filter{filter_descriptor, - ToMIOpenDataType(element_type)}; - ScopedConvolutionDescriptor conv{convolution_descriptor, - ToMIOpenDataType(element_type)}; + TF_ASSIGN_OR_RETURN(auto input_nd, + scope(input_descriptor, ToMIOpenDataType(input_type))); + TF_ASSIGN_OR_RETURN(auto output_nd, + scope(output_descriptor, ToMIOpenDataType(output_type))); + TF_ASSIGN_OR_RETURN(auto filter, + scope(filter_descriptor, ToMIOpenDataType(input_type))); + TF_ASSIGN_OR_RETURN(auto conv, scope(convolution_descriptor)); bool is_backprop = ((kind == dnn::ConvolutionKind::BACKWARD_DATA) || (kind == dnn::ConvolutionKind::BACKWARD_FILTER)); @@ -3534,24 +3547,23 @@ bool MIOpenSupport::GetMIOpenConvolveAlgorithmsImmediateMode( // (call_context == dnn::CallContext::kBackpropFilter); #if TF_ROCM_VERSION >= 50000 - if (is_backprop && (ToMIOpenDataType(element_type) == miopenHalf)) { + if (is_backprop && (ToMIOpenDataType(input_type) == miopenHalf)) { wrap::miopenSetConvolutionAttribute( conv.handle(), MIOPEN_CONVOLUTION_ATTRIB_FP16_ALT_IMPL, 1); } #endif - // First determine the number of algorityhms available + // First determine the number of algorithms available size_t maxSolutionCount = 0; switch (kind) { - case dnn::ConvolutionKind::FORWARD: { + case dnn::ConvolutionKind::FORWARD: + case dnn::ConvolutionKind::FORWARD_BIAS_ACTIVATION: { auto status = wrap::miopenConvolutionForwardGetSolutionCount( miopen.handle(), filter.handle(), input_nd.handle(), conv.handle(), output_nd.handle(), &maxSolutionCount); if (status != miopenStatusSuccess) { - LOG(FATAL) - << "call to miopenConvolutionForwardGetSolutionCount failed: " - << ToString(status); - return false; + return absl::InternalError("call to miopenConvolutionForwardGetSolutionCount failed: " + + ToString(status)); } break; } @@ -3560,10 +3572,8 @@ bool MIOpenSupport::GetMIOpenConvolveAlgorithmsImmediateMode( miopen.handle(), output_nd.handle(), filter.handle(), conv.handle(), input_nd.handle(), &maxSolutionCount); if (status != miopenStatusSuccess) { - LOG(FATAL) << "call 
to miopenConvolutionBackwardDataGetSolutionCount " - "failed: " - << ToString(status); - return false; + return absl::InternalError("call to miopenConvolutionBackwardDataGetSolutionCount " + "failed: " + ToString(status)); } break; } @@ -3572,18 +3582,15 @@ bool MIOpenSupport::GetMIOpenConvolveAlgorithmsImmediateMode( miopen.handle(), output_nd.handle(), input_nd.handle(), conv.handle(), filter.handle(), &maxSolutionCount); if (status != miopenStatusSuccess) { - LOG(FATAL) - << "call to miopenConvolutionBackwardWeightsGetSolutionCount " - "failed: " - << ToString(status); - return false; + return absl::InternalError( + "call to miopenConvolutionBackwardWeightsGetSolutionCount " + "failed: " + ToString(status)); } break; } default: { - LOG(FATAL) << "Unexpected convolution kind " << static_cast(kind); - return false; - break; + return absl::InternalError("Unexpected convolution kind " + + std::to_string(static_cast(kind))); } } @@ -3608,9 +3615,8 @@ bool MIOpenSupport::GetMIOpenConvolveAlgorithmsImmediateMode( solutions.get()); if (status != miopenStatusSuccess) { - LOG(FATAL) << "call to miopenConvolutionForwardGetSolution failed: " - << ToString(status); - return false; + return absl::InternalError("call to miopenConvolutionForwardGetSolution failed: " + + ToString(status)); } VLOG(kConvDebugVlogLevel) @@ -3629,10 +3635,8 @@ bool MIOpenSupport::GetMIOpenConvolveAlgorithmsImmediateMode( output_nd.handle(), solution.solution_id); if (status != miopenStatusSuccess) { - LOG(FATAL) - << "call to miopenConvolutionForwardCompileSolution failed: " - << ToString(status); - return false; + return absl::InternalError("call to miopenConvolutionForwardCompileSolution failed: " + + ToString(status)); } out_algorithms->emplace_back( @@ -3646,10 +3650,8 @@ bool MIOpenSupport::GetMIOpenConvolveAlgorithmsImmediateMode( miopen.handle(), output_nd.handle(), filter.handle(), conv.handle(), input_nd.handle(), maxSolutionCount, &solutionCount, solutions.get()); if (status != miopenStatusSuccess) { - LOG(FATAL) - << "call to miopenConvolutionBackwardDataGetSolution failed: " - << ToString(status); - return false; + return absl::InternalError("call to miopenConvolutionBackwardDataGetSolution failed: " + + ToString(status)); } VLOG(kConvDebugVlogLevel) @@ -3668,10 +3670,8 @@ bool MIOpenSupport::GetMIOpenConvolveAlgorithmsImmediateMode( input_nd.handle(), solution.solution_id); if (status != miopenStatusSuccess) { - LOG(FATAL) << " call to miopenConvolutionBackwardDataCompileSolution " - "failed: " - << ToString(status); - return false; + return absl::InternalError(" call to miopenConvolutionBackwardDataCompileSolution " + "failed: " + ToString(status)); } out_algorithms->emplace_back( @@ -3684,10 +3684,8 @@ bool MIOpenSupport::GetMIOpenConvolveAlgorithmsImmediateMode( miopen.handle(), output_nd.handle(), input_nd.handle(), conv.handle(), filter.handle(), maxSolutionCount, &solutionCount, solutions.get()); if (status != miopenStatusSuccess) { - LOG(FATAL) - << "call to miopenConvolutionBackwardWeightsGetSolution failed: " - << ToString(status); - return false; + return absl::InternalError("call to miopenConvolutionBackwardWeightsGetSolution failed: " + + ToString(status)); } VLOG(kConvDebugVlogLevel) @@ -3706,11 +3704,8 @@ bool MIOpenSupport::GetMIOpenConvolveAlgorithmsImmediateMode( conv.handle(), filter.handle(), solution.solution_id); if (status != miopenStatusSuccess) { - LOG(FATAL) - << "call to miopenConvolutionBackwardWeightsCompileSolution " - "failed: " - << ToString(status); - return false; + return 
absl::InternalError("call to miopenConvolutionBackwardWeightsCompileSolution " + "failed: " + ToString(status)); } out_algorithms->emplace_back( @@ -3719,17 +3714,16 @@ bool MIOpenSupport::GetMIOpenConvolveAlgorithmsImmediateMode( break; } default: { - LOG(FATAL) << "Unexpected convolution kind " << static_cast(kind); - return false; - break; + return absl::InternalError("Unexpected convolution kind " + std::to_string(static_cast(kind))); } } - return true; + return absl::OkStatus(); } -bool MIOpenSupport::GetMIOpenConvolveAlgorithmsFindMode( - dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream, +absl::Status MIOpenSupport::GetMIOpenConvolveAlgorithmsFindMode( + dnn::ConvolutionKind kind, dnn::DataType input_type, dnn::DataType output_type, + Stream* stream, const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, const dnn::FilterDescriptor& filter_descriptor, DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor, @@ -3739,14 +3733,13 @@ bool MIOpenSupport::GetMIOpenConvolveAlgorithmsFindMode( std::vector* out_algorithms) { auto miopen = miopen_->GetHandle(parent_, stream); - ScopedTensorDescriptor input_nd{input_descriptor, - ToMIOpenDataType(element_type)}; - ScopedTensorDescriptor output_nd{output_descriptor, - ToMIOpenDataType(element_type)}; - ScopedFilterDescriptor filter{filter_descriptor, - ToMIOpenDataType(element_type)}; - ScopedConvolutionDescriptor conv{convolution_descriptor, - ToMIOpenDataType(element_type)}; + TF_ASSIGN_OR_RETURN(auto input_nd, + scope(input_descriptor, ToMIOpenDataType(input_type))); + TF_ASSIGN_OR_RETURN(auto output_nd, + scope(output_descriptor, ToMIOpenDataType(output_type))); + TF_ASSIGN_OR_RETURN(auto filter, + scope(filter_descriptor, ToMIOpenDataType(input_type))); + TF_ASSIGN_OR_RETURN(auto conv, scope(convolution_descriptor)); bool is_backprop = ((kind == dnn::ConvolutionKind::BACKWARD_DATA) || (kind == dnn::ConvolutionKind::BACKWARD_FILTER)); @@ -3754,7 +3747,7 @@ bool MIOpenSupport::GetMIOpenConvolveAlgorithmsFindMode( // (call_context == dnn::CallContext::kBackpropFilter); #if TF_ROCM_VERSION >= 50000 - if (is_backprop && (ToMIOpenDataType(element_type) == miopenHalf)) { + if (is_backprop && (ToMIOpenDataType(input_type) == miopenHalf)) { wrap::miopenSetConvolutionAttribute( conv.handle(), MIOPEN_CONVOLUTION_ATTRIB_FP16_ALT_IMPL, 1); } @@ -3763,15 +3756,14 @@ bool MIOpenSupport::GetMIOpenConvolveAlgorithmsFindMode( // Determine the workspace memory size that will need by the call to Find size_t scratch_memory_size = 0; switch (kind) { - case dnn::ConvolutionKind::FORWARD: { + case dnn::ConvolutionKind::FORWARD: + case dnn::ConvolutionKind::FORWARD_BIAS_ACTIVATION: { auto status = wrap::miopenConvolutionForwardGetWorkSpaceSize( miopen.handle(), filter.handle(), input_nd.handle(), conv.handle(), output_nd.handle(), &scratch_memory_size); if (status != miopenStatusSuccess) { - LOG(FATAL) - << "call to miopenConvolutionForwardGetWorkspaceSize failed: " - << ToString(status); - return false; + return absl::InternalError("call to miopenConvolutionForwardGetWorkspaceSize failed: " + + ToString(status)); } break; } @@ -3780,10 +3772,8 @@ bool MIOpenSupport::GetMIOpenConvolveAlgorithmsFindMode( miopen.handle(), output_nd.handle(), filter.handle(), conv.handle(), input_nd.handle(), &scratch_memory_size); if (status != miopenStatusSuccess) { - LOG(FATAL) - << "call to miopenConvolutionBackwardDataGetWorkspaceSize failed: " - << ToString(status); - return false; + return absl::InternalError("call 
to miopenConvolutionBackwardDataGetWorkspaceSize failed: " + + ToString(status)); } break; } @@ -3792,17 +3782,13 @@ bool MIOpenSupport::GetMIOpenConvolveAlgorithmsFindMode( miopen.handle(), output_nd.handle(), input_nd.handle(), conv.handle(), filter.handle(), &scratch_memory_size); if (status != miopenStatusSuccess) { - LOG(FATAL) - << "call to miopenConvolutionBackwardWeightsGetWorkspaceSize " - "failed: " - << ToString(status); - return false; + return absl::InternalError("call to miopenConvolutionBackwardWeightsGetWorkspaceSize " + "failed: " + ToString(status)); } break; } default: { - LOG(FATAL) << "Unexpected convolution kind " << static_cast(kind); - return false; + return absl::InternalError("Unexpected convolution kind " + static_cast(kind)); break; } } @@ -3811,9 +3797,8 @@ bool MIOpenSupport::GetMIOpenConvolveAlgorithmsFindMode( DeviceMemory scratch_memory; if (scratch_memory_size != 0) { if (scratch_allocator == nullptr) { - LOG(FATAL) - << "An allocator must be specified when scratch memory is needed"; - return false; + return absl::InternalError("An allocator must be specified " + "when scratch memory is needed"); } auto allocated = scratch_allocator->AllocateBytes(scratch_memory_size); if (allocated.ok()) { @@ -3826,7 +3811,7 @@ bool MIOpenSupport::GetMIOpenConvolveAlgorithmsFindMode( "larger number (e.g. 8192) to increase the max memory limit.\n" << "\tIncreasing the max memory limit might help resolve this " "error"; - return false; + return absl::InternalError("Out of memory"); } } @@ -3842,7 +3827,8 @@ bool MIOpenSupport::GetMIOpenConvolveAlgorithmsFindMode( bool exhaustiveSearch = false; switch (kind) { - case dnn::ConvolutionKind::FORWARD: { + case dnn::ConvolutionKind::FORWARD: + case dnn::ConvolutionKind::FORWARD_BIAS_ACTIVATION: { auto status = wrap::miopenFindConvolutionForwardAlgorithm( miopen.handle(), input_nd.handle(), input_data.opaque(), filter.handle(), filter_data.opaque(), conv.handle(), @@ -3850,9 +3836,8 @@ bool MIOpenSupport::GetMIOpenConvolveAlgorithmsFindMode( &returnedAlgorithmCount, &returnedAlgorithm, scratch_memory.opaque(), scratch_memory_size, exhaustiveSearch); if (status != miopenStatusSuccess) { - LOG(FATAL) << "call to miopenFindConvolutionForwardAlgorithm failed: " - << ToString(status); - return false; + return absl::InternalError("call to miopenFindConvolutionForwardAlgorithm failed: " + + ToString(status)); } break; } @@ -3864,10 +3849,8 @@ bool MIOpenSupport::GetMIOpenConvolveAlgorithmsFindMode( &returnedAlgorithmCount, &returnedAlgorithm, scratch_memory.opaque(), scratch_memory_size, exhaustiveSearch); if (status != miopenStatusSuccess) { - LOG(FATAL) - << "call to miopenFindConvolutionBackwardDataAlgorithm failed: " - << ToString(status); - return false; + return absl::InternalError("call to miopenFindConvolutionBackwardDataAlgorithm failed: " + + ToString(status)); } break; } @@ -3879,16 +3862,13 @@ bool MIOpenSupport::GetMIOpenConvolveAlgorithmsFindMode( &returnedAlgorithmCount, &returnedAlgorithm, scratch_memory.opaque(), scratch_memory_size, exhaustiveSearch); if (status != miopenStatusSuccess) { - LOG(FATAL) << "call to miopenConvolutionBackwardWeightsAlgorithm " - "failed: " - << ToString(status); - return false; + return absl::InternalError("call to miopenConvolutionBackwardWeightsAlgorithm " + "failed: " + ToString(status)); } break; } default: { - LOG(FATAL) << "Unexpected convolution kind " << static_cast(kind); - return false; + return absl::InternalError("Unexpected convolution kind " + 
std::to_string(static_cast(kind))); break; } } @@ -3896,7 +3876,7 @@ bool MIOpenSupport::GetMIOpenConvolveAlgorithmsFindMode( out_algorithms->emplace_back( GetProfileResultFromConvAlgoPerf(kind, returnedAlgorithm)); - return true; + return absl::OkStatus(); } bool MIOpenSupport::GetRnnAlgorithms( @@ -3932,7 +3912,7 @@ bool MIOpenSupport::DoBatchNormalizationForward( stream, dnn::DataType::kBF16, dnn::DataType::kFloat, x, scale, offset, estimated_mean, estimated_variance, side_input, x_desc, scale_offset_desc, epsilon, exponential_average_factor, activation_mode, y, batch_mean, - batch_var, saved_mean, saved_inv_var, is_training); + batch_var, saved_mean, saved_inv_var, is_training).ok(); } bool MIOpenSupport::DoBatchNormalizationForward( @@ -3953,7 +3933,7 @@ bool MIOpenSupport::DoBatchNormalizationForward( stream, dnn::DataType::kHalf, dnn::DataType::kFloat, x, scale, offset, estimated_mean, estimated_variance, side_input, x_desc, scale_offset_desc, epsilon, exponential_average_factor, activation_mode, y, batch_mean, - batch_var, saved_mean, saved_inv_var, is_training); + batch_var, saved_mean, saved_inv_var, is_training).ok(); } bool MIOpenSupport::DoBatchNormalizationForward( @@ -3973,11 +3953,11 @@ bool MIOpenSupport::DoBatchNormalizationForward( stream, dnn::DataType::kFloat, dnn::DataType::kFloat, x, scale, offset, estimated_mean, estimated_variance, side_input, x_desc, scale_offset_desc, epsilon, exponential_average_factor, activation_mode, y, batch_mean, - batch_var, saved_mean, saved_inv_var, is_training); + batch_var, saved_mean, saved_inv_var, is_training).ok(); } template -bool MIOpenSupport::DoBatchNormalizationForwardImpl( +absl::Status MIOpenSupport::DoBatchNormalizationForwardImpl( Stream* stream, dnn::DataType input_data_type, dnn::DataType scale_data_type, const DeviceMemory& x, const DeviceMemory& scale, const DeviceMemory& offset, @@ -3992,10 +3972,10 @@ bool MIOpenSupport::DoBatchNormalizationForwardImpl( bool is_training) { auto miopen = miopen_->GetHandle(parent_, stream); - ScopedTensorDescriptor x_descriptor{x_desc, - ToMIOpenDataType(input_data_type)}; - ScopedTensorDescriptor scale_offset_descriptor{ - scale_offset_desc, ToMIOpenDataType(scale_data_type)}; + TF_ASSIGN_OR_RETURN(auto x_descriptor, + scope(x_desc, ToMIOpenDataType(input_data_type))); + TF_ASSIGN_OR_RETURN(auto scale_offset_descriptor, + scope(scale_offset_desc, ToMIOpenDataType(scale_data_type))); miopenBatchNormMode_t mode = miopenBNSpatial; float one = 1.0; float zero = 0.0; @@ -4018,11 +3998,10 @@ bool MIOpenSupport::DoBatchNormalizationForwardImpl( const_cast(maybe_inv_var), epsilon); } if (status != miopenStatusSuccess) { - LOG(ERROR) << "failed to enqueue forward batch normalization on stream: " - << ToString(status); - return false; + return absl::InternalError("failed to enqueue forward batch normalization on stream: " + + ToString(status)); } - return true; + return absl::OkStatus(); } bool MIOpenSupport::DoBatchNormalizationBackward( @@ -4041,7 +4020,7 @@ bool MIOpenSupport::DoBatchNormalizationBackward( return DoBatchNormalizationBackwardImpl( stream, miopenBFloat16, miopenFloat, y_backprop, x, scale, mean, inv_var, x_desc, scale_offset_desc, epsilon, x_backprop, scale_backprop, - offset_backprop); + offset_backprop).ok(); } bool MIOpenSupport::DoBatchNormalizationBackward( @@ -4059,7 +4038,7 @@ bool MIOpenSupport::DoBatchNormalizationBackward( return DoBatchNormalizationBackwardImpl( stream, miopenHalf, miopenFloat, y_backprop, x, scale, mean, inv_var, x_desc, scale_offset_desc, 
epsilon, x_backprop, scale_backprop, - offset_backprop); + offset_backprop).ok(); } bool MIOpenSupport::DoBatchNormalizationBackward( @@ -4077,11 +4056,11 @@ bool MIOpenSupport::DoBatchNormalizationBackward( return DoBatchNormalizationBackwardImpl( stream, miopenFloat, miopenFloat, y_backprop, x, scale, mean, variance, x_desc, scale_offset_desc, epsilon, x_backprop, scale_backprop, - offset_backprop); + offset_backprop).ok(); } template -bool MIOpenSupport::DoBatchNormalizationBackwardImpl( +absl::Status MIOpenSupport::DoBatchNormalizationBackwardImpl( Stream* stream, int miopen_input_type, int miopen_scale_type, const DeviceMemory& y_backprop, const DeviceMemory& x, const DeviceMemory& scale, const DeviceMemory& mean, @@ -4090,10 +4069,10 @@ bool MIOpenSupport::DoBatchNormalizationBackwardImpl( DeviceMemory* x_backprop, DeviceMemory* scale_backprop, DeviceMemory* offset_backprop) { auto miopen = miopen_->GetHandle(parent_, stream); - ScopedTensorDescriptor x_descriptor{ - x_desc, static_cast(miopen_input_type)}; - ScopedTensorDescriptor scale_offset_descriptor{ - scale_offset_desc, static_cast(miopen_scale_type)}; + TF_ASSIGN_OR_RETURN(auto x_descriptor, + scope(x_desc, static_cast(miopen_input_type))); + TF_ASSIGN_OR_RETURN(auto scale_offset_descriptor, + scope(scale_offset_desc, static_cast(miopen_scale_type))); miopenBatchNormMode_t mode = miopenBNSpatial; float one = 1.0; float zero = 0.0; @@ -4106,17 +4085,19 @@ bool MIOpenSupport::DoBatchNormalizationBackwardImpl( scale_backprop->opaque(), offset_backprop->opaque(), epsilon, mean.opaque(), variance.opaque()); if (status != miopenStatusSuccess) { - LOG(ERROR) << "failed to enqueue backward batch normalization on stream: " - << ToString(status); - return false; + return absl::InternalError("failed to enqueue backward batch normalization on stream: " + + ToString(status)); } - return true; + return absl::OkStatus(); } -template +template void launchInplaceBiasActivation(hipStream_t stream, void* c_data, - const void* bias_data, int activation_mode, - uint64_t m, uint64_t n, int64_t ldc, + const void* bias_data, + const void* side_input_data, + float side_input_scale, + int activation_mode, + uint64_t batch, uint64_t m, uint64_t n, int64_t ldc, float param); class ROCmFusedMatmulRunner : public dnn::FusedMatmulRunner { @@ -4212,31 +4193,39 @@ absl::Status ROCmFusedMatmulRunner::gemm(Stream* stream, NumericOptions{}, blas::CallContext::kNone); } -template +template absl::Status InplaceBiasActivation(Stream* stream, DeviceMemoryBase c_data, DeviceMemoryBase bias_data, + DeviceMemoryBase side_input_data, + float side_input_scale, dnn::ActivationMode activation_mode, - uint64_t m, uint64_t n, int64_t ldc, - float param) { + uint64_t batch, uint64_t m, uint64_t n, int64_t ldc, + float param, + bool transpose=false) { typedef typename std::conditional< std::is_same_v, __half, typename std::conditional, hip_bfloat16, T>::type>::type CT; - - if (activation_mode == dnn::ActivationMode::kReluX || - activation_mode == dnn::ActivationMode::kBandPass || - activation_mode == dnn::ActivationMode::kLeakyRelu) - - return absl::InvalidArgumentError( - "ROCm InplaceBiasActivation can't be used with " - "parametric activations yet"); - - launchInplaceBiasActivation( + typedef typename std::conditional< + std::is_same_v, __half, + typename std::conditional, + hip_bfloat16, Tbias>::type>::type CTbias; + launchInplaceBiasActivation( AsGpuStreamValue(stream), c_data.opaque(), bias_data.opaque(), - static_cast(activation_mode), m, n, ldc, param); + 
side_input_data.opaque(), side_input_scale, + static_cast(activation_mode)+(transpose?10:0), batch, m, n, ldc, param); return absl::OkStatus(); } +template +absl::Status InplaceBiasActivation(Stream* stream, + DeviceMemory c_data, + DeviceMemory bias_data, + Args... args) { + return InplaceBiasActivation(stream, + DeviceMemoryBase(c_data), DeviceMemoryBase(bias_data), args...); +} + // Launch the operation, with the signature determined by `Sig`. absl::Status ROCmFusedMatmulRunner::operator()( Stream* stream, dnn::ProfileResult* prof, DeviceMemoryBase scratch_memory, @@ -4255,18 +4244,20 @@ absl::Status ROCmFusedMatmulRunner::operator()( return absl::InvalidArgumentError("Unsupported input type"); if (!status.ok()) return status; + + DeviceMemory side_input; if (_input_type == dnn::DataType::kFloat) - return InplaceBiasActivation(stream, c_data, bias_data, - _activation_mode, _m, _n, _ldc, 0.0f); + return InplaceBiasActivation(stream, c_data, bias_data, side_input, 0.0f, + _activation_mode, 1, _m, _n, _ldc, 0.0f); else if (_input_type == dnn::DataType::kHalf) return InplaceBiasActivation( - stream, c_data, bias_data, _activation_mode, _m, _n, _ldc, 0.0f); + stream, c_data, bias_data, side_input, 0.0f, _activation_mode, 1, _m, _n, _ldc, 0.0f); else if (_input_type == dnn::DataType::kBF16) return InplaceBiasActivation( - stream, c_data, bias_data, _activation_mode, _m, _n, _ldc, 0.0f); + stream, c_data, bias_data, side_input, 0.0f, _activation_mode, 1, _m, _n, _ldc, 0.0f); else if (_input_type == dnn::DataType::kDouble) - return InplaceBiasActivation(stream, c_data, bias_data, - _activation_mode, _m, _n, _ldc, 0.0f); + return InplaceBiasActivation(stream, c_data, bias_data, side_input, 0.0f, + _activation_mode, 1, _m, _n, _ldc, 0.0f); else return absl::InvalidArgumentError("Unsupported input type"); } @@ -4344,9 +4335,9 @@ absl::Status MIOpenSupport::DoPoolForward( auto miopen_dtype = element_type == dnn::DataType::kFloat ? miopenFloat : miopenHalf; - ScopedTensorDescriptor src_desc{input_dimensions, miopen_dtype}; - ScopedTensorDescriptor dest_desc{output_dimensions, miopen_dtype}; - ScopedPoolingDescriptor pooling_desc{pooling_dimensions}; + TF_ASSIGN_OR_RETURN(auto src_desc, scope(input_dimensions, miopen_dtype)); + TF_ASSIGN_OR_RETURN(auto dest_desc, scope(output_dimensions, miopen_dtype)); + TF_ASSIGN_OR_RETURN(auto pooling_desc, scope(pooling_dimensions)); bool do_backward = false; uint8* workspace = nullptr; @@ -4498,9 +4489,9 @@ absl::Status MIOpenSupport::DoPoolBackward( auto miopen_dtype = element_type == dnn::DataType::kFloat ? 
miopenFloat : miopenHalf; - ScopedTensorDescriptor src_desc{input_dimensions, miopen_dtype}; - ScopedTensorDescriptor dest_desc{output_dimensions, miopen_dtype}; - ScopedPoolingDescriptor pooling_desc{pooling_dimensions}; + TF_ASSIGN_OR_RETURN(auto src_desc, scope(input_dimensions, miopen_dtype)); + TF_ASSIGN_OR_RETURN(auto dest_desc, scope(output_dimensions, miopen_dtype)); + TF_ASSIGN_OR_RETURN(auto pooling_desc, scope(pooling_dimensions)); uint8* workspace_ptr = 0; DeviceMemory workspace; @@ -4588,6 +4579,17 @@ absl::Status MIOpenSupport::DoPoolBackward( return absl::OkStatus(); } +#define ASSIGN_OR_RETURN_FALSE(lhs, rexpr) \ + ASSIGN_OR_RETURN_FALSE_IMPL( \ + TF_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, rexpr) + +#define ASSIGN_OR_RETURN_FALSE_IMPL(statusor, lhs, rexpr) \ + auto statusor = (rexpr); \ + if (TF_PREDICT_FALSE(!statusor.ok())) { \ + return false; \ + } \ + lhs = std::move(statusor).value() + bool MIOpenSupport::DoNormalizeWithDimensions( Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor, const dnn::BatchDescriptor& dimensions, @@ -4605,8 +4607,8 @@ bool MIOpenSupport::DoNormalizeWithDimensions( auto miopen = miopen_->GetHandle(parent_, stream); // Launch the normalization. - ScopedTensorDescriptor dims{dimensions, miopenFloat}; - ScopedNormalizeDescriptor normalize{normalize_descriptor}; + ASSIGN_OR_RETURN_FALSE(auto dims, scope(dimensions, miopenFloat)); + ASSIGN_OR_RETURN_FALSE(auto normalize, scope(normalize_descriptor)); // Alpha is the scaling factor for input. float alpha = 1.0f; @@ -4643,8 +4645,8 @@ bool MIOpenSupport::DoNormalizeBackwardWithDimensions( auto miopen = miopen_->GetHandle(parent_, stream); - ScopedTensorDescriptor dims{dimensions, miopenFloat}; - ScopedNormalizeDescriptor normalize{normalize_descriptor}; + ASSIGN_OR_RETURN_FALSE(auto dims, scope(dimensions, miopenFloat)); + ASSIGN_OR_RETURN_FALSE(auto normalize, scope(normalize_descriptor)); float alpha = 1.0f; float beta = 0.0f; @@ -4728,9 +4730,9 @@ bool MIOpenSupport::DeriveOutputBatchDescriptor( const FilterDescriptor& filter_descriptor, const dnn::ConvolutionDescriptor& convolution_descriptor, dnn::BatchDescriptor* output_batch_descriptor) { - ScopedTensorDescriptor input_nd{batch_descriptor, miopenFloat}; - ScopedFilterDescriptor filter{filter_descriptor, miopenFloat}; - ScopedConvolutionDescriptor conv{convolution_descriptor, miopenFloat}; + ASSIGN_OR_RETURN_FALSE(auto input_nd, scope(batch_descriptor, miopenFloat)); + ASSIGN_OR_RETURN_FALSE(auto filter, scope(filter_descriptor, miopenFloat)); + ASSIGN_OR_RETURN_FALSE(auto conv, scope(convolution_descriptor)); int dn = batch_descriptor.ndims() + 2; std::vector dims(dn); // in BDYX @@ -4754,6 +4756,388 @@ bool MIOpenSupport::DeriveOutputBatchDescriptor( return true; } +class RocmFusedConvRunner : public dnn::FusedConvRunner { +public: + std::string ToString() const override { + return MakeAlgorithmDesc().ToString(); + } + + uint64_t GetWorkspaceSize() const override { return workspace_size_; } + + absl::StatusOr ToAlgorithmDesc() const override { + return MakeAlgorithmDesc(); + } + + absl::Status operator()(Stream* stream, dnn::ProfileResult* profile_result, + DeviceMemoryBase scratch_memory, + DeviceMemoryBase input_data, + DeviceMemoryBase filter_data, + DeviceMemoryBase side_input_data, + DeviceMemoryBase bias_data, + DeviceMemoryBase output_data) const override { + VLOG(2) << "RocmFusedConvRunner()"; + if (parent_ != stream->parent()) { + return absl::InternalError( + "RocmFusedConvRunner cached across 
multiple StreamExecutors."); + } + + // We can't reliably detect whether this sequence can be fused until + // we come here and actually try to fuse it. So, we need a fallback. + bool do_unfused = (side_input_scale_ != 0.0) + || !fusion_plan_.CompilationSucceeded(); + + if(do_unfused) + return execute_unfused(stream, profile_result, + scratch_memory, input_data, filter_data, side_input_data, + bias_data, output_data); + auto algo = MakeAlgorithmDesc(); + auto miopen = miopen_->GetHandle(parent_, stream); + fusion_plan_.SetConvolutionArgs(filter_data.opaque()); + fusion_plan_.SetBiasArgs(bias_data.opaque()); + if (activation_desc_.miopen_activation_mode_ != miopenActivationPASTHRU) + fusion_plan_.SetActivationForwardArgs(activation_desc_); + + std::optional timer; + if (profile_result) { + auto timer_or_status = GpuTimer::Create(AsGpuStream(stream)); + if (!timer_or_status.ok()) { + LOG(ERROR) << "Failed to create timer"; + return absl::InternalError("Failed to start timer"); + } + timer.emplace(std::move(*timer_or_status)); + } + + miopenStatus_t status; + status = wrap::miopenExecuteFusionPlan( + miopen.handle(), fusion_plan_.fusion_plan_, input_nd_.handle(), + input_data.opaque(), output_nd_.handle(), output_data.opaque(), + fusion_plan_.fusion_args_); + + if (status != miopenStatusSuccess) { + LOG(ERROR) << "Failed to enqueue fused convolution on stream: " + << stream_executor::gpu::ToString(status); + return absl::InternalError( + "Failed to enqueue fused convolution on stream: " + + stream_executor::gpu::ToString(status)); + } + + if (profile_result) { + absl::StatusOr elapsed = timer->GetElapsedDuration(); + if (!elapsed.ok()) { + LOG(ERROR) << "Failed to get elapsed duration"; + return absl::InternalError("Timer failure"); + } + profile_result->set_elapsed_time_in_ms( + absl::ToDoubleMilliseconds(*elapsed)); + profile_result->set_algorithm(algo); + profile_result->set_scratch_size(scratch_memory.size()); + } + + return absl::OkStatus(); + } + + public: + // Queries the workspace size and constructs a 'RocmFusedConvRunner'. 
+ static absl::StatusOr> Create( + GpuExecutor* parent, Stream* stream, MIOpenAccess* miopen, + const dnn::AlgorithmDesc& algo, + dnn::DataType input_type, + dnn::DataType bias_type, + double conv_scale, double side_input_scale, + double leakyrelu_alpha, + BatchDescriptor input_nd, BatchDescriptor output_nd, + FilterDescriptor filter, BatchDescriptor bias_nd, + ConvolutionDescriptor conv, + dnn::ActivationMode activation + ) { + TF_ASSIGN_OR_RETURN(auto input_nd_, scope(input_nd, ToMIOpenDataType(input_type, input_nd.layout()))); + TF_ASSIGN_OR_RETURN(auto output_nd_, scope(output_nd, ToMIOpenDataType(input_type, input_nd.layout()))); + TF_ASSIGN_OR_RETURN(auto filter_, scope(filter, ToMIOpenDataType(input_type))); + TF_ASSIGN_OR_RETURN(auto bias_nd_, scope(bias_nd, ToMIOpenDataType(bias_type))); + TF_ASSIGN_OR_RETURN(auto conv_, scope(conv)); + + TF_ASSIGN_OR_RETURN(auto activation_desc, + ScopedActivationDescriptor::Create(activation, leakyrelu_alpha)); + + TF_ASSIGN_OR_RETURN(auto fusion_plan, + ScopedFusionPlanConvolutionBiasActivation::Create( + miopen->GetHandle(parent, stream).handle(), + input_nd_.handle(), filter_.handle(), + conv_.handle(), bias_nd_.handle(), activation_desc)); + + VLOG(2) << "RocmFusedConvRunner"; + auto mi = miopen->GetHandle(parent, stream); + + size_t maxSolutionCount = 0; + auto status = wrap::miopenConvolutionForwardGetSolutionCount( + mi.handle(), filter_.handle(), input_nd_.handle(), conv_.handle(), + output_nd_.handle(), &maxSolutionCount); + + size_t solutionCount = 0; + std::unique_ptr solutions( + new miopenConvSolution_t[maxSolutionCount]); + + status = wrap::miopenConvolutionForwardGetSolution( + mi.handle(), filter_.handle(), input_nd_.handle(), conv_.handle(), + output_nd_.handle(), maxSolutionCount, &solutionCount, + solutions.get()); + + VLOG(2) << solutionCount << " solutions"; + + if(solutionCount==0) + return absl::InternalError("No algorithms found"); + + size_t workspace_size_1 = solutions[0].workspace_size; + size_t true_workspace_size = 0; + status = wrap::miopenConvolutionForwardGetWorkSpaceSize( + mi.handle(), filter_.handle(), input_nd_.handle(), conv_.handle(), + output_nd_.handle(), &true_workspace_size); + + VLOG(2) << "True workspace size " << workspace_size_1 << " " << true_workspace_size; + + auto obj = new RocmFusedConvRunner(parent, stream, miopen, + int64_t (solutions[0].solution_id), + true_workspace_size, + input_type, + bias_type, + conv_scale, side_input_scale, + leakyrelu_alpha, + input_nd, output_nd, + filter, bias_nd, + conv, + activation, + input_nd_, + output_nd_, + filter_, + bias_nd_, + conv_, + activation_desc, + fusion_plan); + + return std::unique_ptr(obj); + } + + private: + // Private to prevent passing in the wrong workspace_size. 
+ RocmFusedConvRunner(GpuExecutor* parent, Stream* stream, MIOpenAccess* miopen, + int64_t algo_id, + size_t workspace_size, + dnn::DataType input_type, + dnn::DataType bias_type, + double conv_scale, double side_input_scale, + double leakyrelu_alpha, + BatchDescriptor dnn_input_nd, BatchDescriptor dnn_output_nd, + FilterDescriptor dnn_filter, BatchDescriptor dnn_bias_nd, + ConvolutionDescriptor dnn_conv, + dnn::ActivationMode activation, + ScopedTensorDescriptor& input_nd, + ScopedTensorDescriptor& output_nd, + ScopedFilterDescriptor& filter, + ScopedTensorDescriptor& bias_nd, + ScopedConvolutionDescriptor& conv, + ScopedActivationDescriptor& activation_desc, + ScopedFusionPlanConvolutionBiasActivation& fusion_plan + ) + : parent_(parent), + miopen_(miopen), + algo_id_(algo_id), + workspace_size_(workspace_size), + input_type_(input_type), + bias_type_(bias_type), + + conv_scale_(conv_scale), + side_input_scale_(side_input_scale), + leakyrelu_alpha_(leakyrelu_alpha), + side_input_scale_f32_(float(side_input_scale)), + + dnn_input_nd_(dnn_input_nd), + dnn_output_nd_(dnn_output_nd), + dnn_filter_(dnn_filter), + dnn_bias_nd_(dnn_bias_nd), + dnn_conv_(dnn_conv), + + activation_mode_(activation), + + input_nd_(std::move(input_nd)), + output_nd_(std::move(output_nd)), + filter_(std::move(filter)), + bias_nd_(std::move(bias_nd)), + conv_(std::move(conv)), + activation_desc_(std::move(activation_desc)), + fusion_plan_(std::move(fusion_plan)) {} + + absl::Status execute_unfused(Stream* stream, dnn::ProfileResult* profile_result, + DeviceMemoryBase scratch_memory, + DeviceMemoryBase input_data, + DeviceMemoryBase filter_data, + DeviceMemoryBase side_input_data, + DeviceMemoryBase bias_data, + DeviceMemoryBase output_data) const { + auto miopen = miopen_->GetHandle(parent_, stream); + auto status = wrap::miopenConvolutionForwardImmediate( + miopen.handle(), filter_.handle(), filter_data.opaque(), + input_nd_.handle(), input_data.opaque(), conv_.handle(), + output_nd_.handle(), output_data.opaque(), + scratch_memory.opaque(), scratch_memory.size(), + static_cast(algo_id_)); + if (status != miopenStatusSuccess) { + VLOG(0) << "Failed to enqueue convolution: " + << stream_executor::gpu::ToString(status); + return absl::InternalError( + "Failed to enqueue convolution: " + + stream_executor::gpu::ToString(status)); + } + + int batch; + std::vector dims_output = dnn_output_nd_.full_dims(dnn_output_nd_.layout()); + int rank = dims_output.size(); + if(rank != 4 && rank != 5) + return absl::InternalError("RocmFusedConvRunner expects 4d or 5d descriptors"); + int d1 = 1, d2 = 1; + bool bNCHW = (dnn_output_nd_.layout() != dnn::DataLayout::kBatchYXDepth); + batch = dims_output[0]; + int w, h=1; + if(bNCHW) { + d1 = dims_output[1]; + for(int i=2; i(output_data), DeviceMemory(bias_data)); + else if (input_type_ == dnn::DataType::kHalf && bias_type_ == dnn::DataType::kFloat) + biasActStatus = inplace_call(DeviceMemory(output_data), DeviceMemory(bias_data)); + else if (input_type_ == dnn::DataType::kHalf && bias_type_ == dnn::DataType::kHalf) + biasActStatus = inplace_call(DeviceMemory(output_data), DeviceMemory(bias_data)); + else if (input_type_ == dnn::DataType::kBF16 && bias_type_ == dnn::DataType::kFloat) + biasActStatus = inplace_call(DeviceMemory(output_data), DeviceMemory(bias_data)); + else if (input_type_ == dnn::DataType::kBF16 && bias_type_ == dnn::DataType::kBF16) + biasActStatus = inplace_call(DeviceMemory(output_data), DeviceMemory(bias_data)); + else + return absl::InternalError("Unsupported data 
type"); + + return absl::OkStatus(); + } + + // Internal form of ToAlgorithmDesc without the StatusOr. + dnn::AlgorithmDesc MakeAlgorithmDesc() const { + return {algo_id_, /*tensor_ops_enabled_*/ true, workspace_size_}; + } + + std::string desc_; + + GpuExecutor* parent_; + MIOpenAccess* miopen_; + int64_t algo_id_; + size_t workspace_size_; + dnn::DataType input_type_, bias_type_; + double conv_scale_, side_input_scale_, leakyrelu_alpha_; + float side_input_scale_f32_; + dnn::ActivationMode activation_mode_; + + BatchDescriptor dnn_input_nd_; + BatchDescriptor dnn_output_nd_; + FilterDescriptor dnn_filter_; + BatchDescriptor dnn_bias_nd_; + ConvolutionDescriptor dnn_conv_; + + ScopedTensorDescriptor input_nd_; + ScopedTensorDescriptor output_nd_; + ScopedFilterDescriptor filter_; + ScopedTensorDescriptor bias_nd_; + ScopedConvolutionDescriptor conv_; + mutable ScopedActivationDescriptor activation_desc_; + mutable ScopedFusionPlanConvolutionBiasActivation fusion_plan_; +}; + + +absl::StatusOr> +MIOpenSupport::FusedConvolveRunnerFromDesc( + Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, + dnn::ConvolutionKind kind, dnn::DataType input_type, + dnn::DataType bias_type, dnn::DataType output_type, double conv_scale, + double side_input_scale, double leakyrelu_alpha, + const dnn::BatchDescriptor& input_descriptor, + const dnn::FilterDescriptor& filter_descriptor, + const dnn::BatchDescriptor& bias_descriptor, + const dnn::BatchDescriptor& output_descriptor, + const dnn::ConvolutionDescriptor& convolution_descriptor, + dnn::ActivationMode activation_mode) { + + VLOG(2) << "MIOpenSupport::FusedConvolveRunnerFromDesc " + << filter_descriptor.ndims() << " " + << side_input_scale << " " + << convolution_descriptor.ToString() + << getTypeName(input_type)<< " " + << getTypeName(bias_type) << " " + << getTypeName(output_type); + + // note: these checks need to be duplicated in XLA logic, because XLA calls + // this function directly and it terminates the process on error + + return RocmFusedConvRunner::Create( + parent_, stream, miopen_.get(), algorithm_desc, input_type, bias_type, + conv_scale, side_input_scale, leakyrelu_alpha, + input_descriptor, output_descriptor, filter_descriptor, bias_descriptor, + convolution_descriptor, activation_mode); +} + +absl::Status MIOpenSupport::GetFusedConvolveRunners( + bool use_cudnn_frontend, dnn::ConvolutionKind kind, + dnn::DataType input_type, dnn::DataType bias_type, + dnn::DataType output_type, double conv_scale, + double side_input_scale, double leakyrelu_alpha, Stream* stream, + const dnn::BatchDescriptor& input_descriptor, + const dnn::FilterDescriptor& filter_descriptor, + const dnn::BatchDescriptor& bias_descriptor, + const dnn::BatchDescriptor& output_descriptor, + const dnn::ConvolutionDescriptor& convolution_descriptor, + bool use_fallback, dnn::ActivationMode activation_mode, + const NumericOptions& numeric_options, + std::vector>* out_exec_plans) { + + VLOG(2) << "MIOpenSupport::GetFusedConvolveRunners"; + VLOG(2) << "filter_descriptor " << filter_descriptor.ndims(); + + std::vector algorithms{ + // clang-format off + dnn::AlgorithmDesc(miopenConvolutionFwdAlgoGEMM, false, 0), + dnn::AlgorithmDesc(miopenConvolutionFwdAlgoDirect, false, 0), + dnn::AlgorithmDesc(miopenConvolutionFwdAlgoFFT, false, 0), + dnn::AlgorithmDesc(miopenConvolutionFwdAlgoWinograd, false, 0), + // clang-format on + }; + + for (const auto& algo : algorithms) { + auto runner_or = FusedConvolveRunnerFromDesc( + stream, algo, kind, input_type, bias_type, output_type, 
conv_scale, + side_input_scale, leakyrelu_alpha, input_descriptor, + filter_descriptor, bias_descriptor, output_descriptor, + convolution_descriptor, activation_mode); + if (!runner_or.ok()) + continue; + out_exec_plans->push_back(std::move(runner_or).value()); + } + + VLOG(2) << "MIOpenSupport::GetFusedConvolveRunners returns " << out_exec_plans->size() << " runners"; + return absl::OkStatus(); +} + bool UseNhwcLayoutForRocm() { #if TF_ROCM_VERSION >= 50100 static bool is_enabled = [] { diff --git a/xla/stream_executor/rocm/rocm_dnn.h b/xla/stream_executor/rocm/rocm_dnn.h index ecaffd3cad392..3194be1b384ed 100644 --- a/xla/stream_executor/rocm/rocm_dnn.h +++ b/xla/stream_executor/rocm/rocm_dnn.h @@ -251,8 +251,33 @@ class MIOpenSupport : public dnn::DnnSupport { const dnn::BatchDescriptor& output_descriptor, const dnn::ConvolutionDescriptor& convolution_descriptor) override; + absl::StatusOr> FusedConvolveRunnerFromDesc( + Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, + dnn::ConvolutionKind kind, dnn::DataType input_type, + dnn::DataType bias_type, dnn::DataType output_type, double conv_scale, + double side_input_scale, double leakyrelu_alpha, + const dnn::BatchDescriptor& input_descriptor, + const dnn::FilterDescriptor& filter_descriptor, + const dnn::BatchDescriptor& bias_descriptor, + const dnn::BatchDescriptor& output_descriptor, + const dnn::ConvolutionDescriptor& convolution_descriptor, + dnn::ActivationMode activation_mode) override; + bool GetMIOpenConvolveAlgorithms( - dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream, + dnn::ConvolutionKind kind, dnn::DataType input_type, dnn::DataType output_type, + Stream* stream, + const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, + const dnn::FilterDescriptor& filter_descriptor, + DeviceMemoryBase filter_data, + const dnn::BatchDescriptor& output_descriptor, + DeviceMemoryBase output_data, + const dnn::ConvolutionDescriptor& convolution_descriptor, + ScratchAllocator* scratch_allocator, + std::vector* out_algorithms); + + absl::Status GetMIOpenConvolveAlgorithmsImmediateMode( + dnn::ConvolutionKind kind, dnn::DataType input_type, dnn::DataType output_type, + Stream* stream, const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, const dnn::FilterDescriptor& filter_descriptor, DeviceMemoryBase filter_data, @@ -260,7 +285,19 @@ class MIOpenSupport : public dnn::DnnSupport { DeviceMemoryBase output_data, const dnn::ConvolutionDescriptor& convolution_descriptor, ScratchAllocator* scratch_allocator, - std::vector* out_algorithms) override; + std::vector* out_algorithms); + + absl::Status GetMIOpenConvolveAlgorithmsFindMode( + dnn::ConvolutionKind kind, dnn::DataType input_type, dnn::DataType output_type, + Stream* stream, + const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, + const dnn::FilterDescriptor& filter_descriptor, + DeviceMemoryBase filter_data, + const dnn::BatchDescriptor& output_descriptor, + DeviceMemoryBase output_data, + const dnn::ConvolutionDescriptor& convolution_descriptor, + ScratchAllocator* scratch_allocator, + std::vector* out_algorithms); bool GetRnnAlgorithms( std::vector* out_algorithms) override; @@ -389,6 +426,20 @@ class MIOpenSupport : public dnn::DnnSupport { std::vector>* out_exec_plans) override; + absl::Status GetFusedConvolveRunners( + bool use_cudnn_frontend, dnn::ConvolutionKind kind, + dnn::DataType input_type, dnn::DataType bias_type, + dnn::DataType output_type, double conv_scale, double side_input_scale, + 
double leakyrelu_alpha, Stream* stream, + const dnn::BatchDescriptor& input_descriptor, + const dnn::FilterDescriptor& filter_descriptor, + const dnn::BatchDescriptor& bias_descriptor, + const dnn::BatchDescriptor& output_descriptor, + const dnn::ConvolutionDescriptor& convolution_descriptor, bool use_fallback, + dnn::ActivationMode activation_mode, + const NumericOptions& numeric_options, + std::vector>* out_exec_plans) override; + absl::Status DoPoolForward(dnn::DataType element_type, Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions, const dnn::BatchDescriptor& input_dimensions, @@ -469,7 +520,7 @@ class MIOpenSupport : public dnn::DnnSupport { bool m_pooling_cache_enabled = false; template - bool DoBatchNormalizationForwardImpl( + absl::Status DoBatchNormalizationForwardImpl( Stream* stream, dnn::DataType input_data_type, dnn::DataType scale_data_type, const DeviceMemory& x, const DeviceMemory& scale, const DeviceMemory& offset, @@ -484,7 +535,7 @@ class MIOpenSupport : public dnn::DnnSupport { bool is_training); template - bool DoBatchNormalizationBackwardImpl( + absl::Status DoBatchNormalizationBackwardImpl( Stream* stream, int miopen_input_type, int miopen_scale_type, const DeviceMemory& y_backprop, const DeviceMemory& x, const DeviceMemory& scale, const DeviceMemory& mean, @@ -569,28 +620,6 @@ class MIOpenSupport : public dnn::DnnSupport { ScratchAllocator* scratch_allocator, DeviceMemory* scratch_memory, int* ctc_loss_algo_id) override; - bool GetMIOpenConvolveAlgorithmsImmediateMode( - dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream, - const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, - const dnn::FilterDescriptor& filter_descriptor, - DeviceMemoryBase filter_data, - const dnn::BatchDescriptor& output_descriptor, - DeviceMemoryBase output_data, - const dnn::ConvolutionDescriptor& convolution_descriptor, - ScratchAllocator* scratch_allocator, - std::vector* out_algorithms); - - bool GetMIOpenConvolveAlgorithmsFindMode( - dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream, - const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, - const dnn::FilterDescriptor& filter_descriptor, - DeviceMemoryBase filter_data, - const dnn::BatchDescriptor& output_descriptor, - DeviceMemoryBase output_data, - const dnn::ConvolutionDescriptor& convolution_descriptor, - ScratchAllocator* scratch_allocator, - std::vector* out_algorithms); - MIOpenSupport(const MIOpenSupport&) = delete; void operator=(const MIOpenSupport&) = delete; }; diff --git a/xla/stream_executor/rocm/rocm_helpers.cu.cc b/xla/stream_executor/rocm/rocm_helpers.cu.cc index cf12ffdaeb47c..7f3eda32edd55 100644 --- a/xla/stream_executor/rocm/rocm_helpers.cu.cc +++ b/xla/stream_executor/rocm/rocm_helpers.cu.cc @@ -83,16 +83,27 @@ __device__ float sigmoid(float x) { return __expf(x) / (__expf(x) + 1.); } -template +template __global__ void launchInplaceBiasActivation_kernel(T* c_data, - const T* bias_data, + const Tbias* bias_data, + const T* side_input_data, float side_input_scale, uint64_t m, uint64_t n, - int64_t ldc, float param) { + int64_t ldc, float param, + int transpose) { uint64_t x = threadIdx.x + blockIdx.x * blockDim.x; uint64_t y = threadIdx.y + blockIdx.y * blockDim.y; + uint64_t z = blockIdx.z; if (x >= n || y >= m) return; - float v = static_cast(c_data[x + y * ldc]) + + float v; + uint64_t addr = x+y*ldc + z * m * n; + if(!transpose) + v = static_cast(c_data[addr]) + static_cast(bias_data[x]); + else + v = 
static_cast(c_data[addr]) + + static_cast(bias_data[y]); + if(side_input_data != 0) + v += float(side_input_data[addr]) * side_input_scale; if (act_mode == 1) v = sigmoid(v); else if (act_mode == 2) @@ -111,58 +122,60 @@ __global__ void launchInplaceBiasActivation_kernel(T* c_data, v = v > 0.0f ? v : param * v; else if (act_mode == 9) v = 0.5 * v * (1 + erf(v / sqrt(2.0f))); - c_data[x + y * ldc] = (T)v; + c_data[addr] = (T)v; } -template +template void launchInplaceBiasActivation(hipStream_t stream, void* c_data, - const void* bias_data, int activation_mode, - uint64_t m, uint64_t n, int64_t ldc, + const void* bias_data, + const void* side_input_data, float side_input_scale, + int activation_mode, + uint64_t batch, uint64_t m, uint64_t n, int64_t ldc, float param) { uint64_t bx = min(n, static_cast(256)); uint64_t by = min(m, static_cast(256) / bx); uint64_t gx = (n + bx - 1) / bx; uint64_t gy = (m + by - 1) / by; - auto kernel = launchInplaceBiasActivation_kernel; + int transpose = (activation_mode >= 10); + activation_mode %= 10; + auto kernel = launchInplaceBiasActivation_kernel; if (activation_mode == 1) - kernel = launchInplaceBiasActivation_kernel; + kernel = launchInplaceBiasActivation_kernel; else if (activation_mode == 2) - kernel = launchInplaceBiasActivation_kernel; + kernel = launchInplaceBiasActivation_kernel; else if (activation_mode == 3) - kernel = launchInplaceBiasActivation_kernel; + kernel = launchInplaceBiasActivation_kernel; else if (activation_mode == 4) - kernel = launchInplaceBiasActivation_kernel; + kernel = launchInplaceBiasActivation_kernel; else if (activation_mode == 5) - kernel = launchInplaceBiasActivation_kernel; + kernel = launchInplaceBiasActivation_kernel; else if (activation_mode == 6) - kernel = launchInplaceBiasActivation_kernel; + kernel = launchInplaceBiasActivation_kernel; else if (activation_mode == 7) - kernel = launchInplaceBiasActivation_kernel; + kernel = launchInplaceBiasActivation_kernel; else if (activation_mode == 8) - kernel = launchInplaceBiasActivation_kernel; + kernel = launchInplaceBiasActivation_kernel; else if (activation_mode == 9) - kernel = launchInplaceBiasActivation_kernel; + kernel = launchInplaceBiasActivation_kernel; - hipLaunchKernelGGL(kernel, dim3(gx, gy, 1), dim3(bx, by, 1), 0, stream, - static_cast(c_data), static_cast(bias_data), - m, n, ldc, param); + hipLaunchKernelGGL(kernel, dim3(gx, gy, batch), dim3(bx, by, 1), 0, stream, + static_cast(c_data), static_cast(bias_data), + static_cast(side_input_data), side_input_scale, + m, n, ldc, param, transpose); } -template void launchInplaceBiasActivation<__half>( - hipStream_t stream, void* c_data, const void* bias_data, - int activation_mode, uint64_t m, uint64_t n, int64_t ldc, float param); - -template void launchInplaceBiasActivation( - hipStream_t stream, void* c_data, const void* bias_data, - int activation_mode, uint64_t m, uint64_t n, int64_t ldc, float param); - -template void launchInplaceBiasActivation( - hipStream_t stream, void* c_data, const void* bias_data, - int activation_mode, uint64_t m, uint64_t n, int64_t ldc, float param); - -template void launchInplaceBiasActivation( - hipStream_t stream, void* c_data, const void* bias_data, - int activation_mode, uint64_t m, uint64_t n, int64_t ldc, float param); +#define INSTANTIATE_BIAS_ACTIVATION(X, Y) \ +template void launchInplaceBiasActivation( \ + hipStream_t stream, void* c_data, const void* bias_data, \ + const void* side_input_data, float side_input_scale, \ + int activation_mode, uint64_t batch, uint64_t m, 
uint64_t n, int64_t ldc, float param); + +INSTANTIATE_BIAS_ACTIVATION(__half, __half) +INSTANTIATE_BIAS_ACTIVATION(__half, float) +INSTANTIATE_BIAS_ACTIVATION(hip_bfloat16, hip_bfloat16) +INSTANTIATE_BIAS_ACTIVATION(hip_bfloat16, float) +INSTANTIATE_BIAS_ACTIVATION(float, float) +INSTANTIATE_BIAS_ACTIVATION(double, double) }; // namespace gpu }; // namespace stream_executor diff --git a/xla/tests/convolution_test.cc b/xla/tests/convolution_test.cc index c4f3ec42c1c64..3718edc01c90c 100644 --- a/xla/tests/convolution_test.cc +++ b/xla/tests/convolution_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/strings/str_replace.h" #include "xla/array2d.h" #include "xla/array4d.h" #include "xla/client/global_data.h" @@ -1755,7 +1756,7 @@ ENTRY TestComputation { } XLA_TEST_F(ConvolutionHloTest, TestFusedConv2D) { - constexpr char kHlo[] = R"( + std::string kHlo = R"( HloModule TestModule ENTRY TestComputation { @@ -1765,11 +1766,47 @@ ENTRY TestComputation { %bias = f32[32] parameter(2) %broadcasted_bias = f32[8,5,5,32] broadcast(%bias), dimensions={3} %add = f32[8,5,5,32] add(%conv, %broadcasted_bias) +)"; + + std::string kHloNoPad = R"( +HloModule TestModule + +ENTRY TestComputation { + %p0 = f32[8,7,7,1] parameter(0) + %p1 = f32[3,3,1,32] parameter(1) + %conv = f32[8,5,5,32] convolution(p0, p1), window={size=3x3 pad=0_0x0_0}, dim_labels=b01f_01io->b01f + %bias = f32[32] parameter(2) + %broadcasted_bias = f32[8,5,5,32] broadcast(%bias), dimensions={3} + %add = f32[8,5,5,32] add(%conv, %broadcasted_bias) +)"; + + std::string kHloRELU = R"( + %zero = f32[] constant(0) %zeros = f32[8,5,5,32] broadcast(%zero), dimensions={} ROOT relu = f32[8,5,5,32] maximum(%zeros, %add) })"; - EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.01, 0.01})); + + std::string kHloTANH = R"( + ROOT result = f32[8,5,5,32] tanh(%add) +})"; + + std::string kHloELU = R"( + %zero = f32[] constant(0) + %zeros = f32[8,5,5,32] broadcast(%zero), dimensions={} + %one = f32[] constant(1) + %ones = f32[8,5,5,32] broadcast(%one), dimensions={} + %exp = f32[8,5,5,32] exponential(%add) + %expm1 = f32[8,5,5,32] subtract(%exp, %ones) + %sgn = pred[8,5,5,32] compare(%add, %zeros), direction=GT + ROOT elu = f32[8,5,5,32] select(%sgn, %add, %expm1) +})"; + + EXPECT_TRUE(RunAndCompare(kHlo+kHloRELU, ErrorSpec{0.01, 0.01})); + EXPECT_TRUE(RunAndCompare(kHlo+kHloTANH, ErrorSpec{0.01, 0.01})); + EXPECT_TRUE(RunAndCompare(kHlo+kHloELU, ErrorSpec{0.01, 0.01})); + EXPECT_TRUE(RunAndCompare(absl::StrReplaceAll(kHlo+kHloRELU, {{"f32", "f16"}}), ErrorSpec{0.03, 0.03})); + EXPECT_TRUE(RunAndCompare(absl::StrReplaceAll(kHloNoPad+kHloRELU, {{"f32", "f16"}}), ErrorSpec{0.03, 0.03})); } XLA_TEST_F(ConvolutionHloTest, TestFusedConv3D) { From 115c33f4982ec1e3da01ebdb0be09964bdf4dc9e Mon Sep 17 00:00:00 2001 From: Thomas Joerg Date: Tue, 4 Jun 2024 13:02:57 -0700 Subject: [PATCH 6/7] [XLA:GPU] Pass CUDA and cuDNN versions (or ROCm counterparts) explicitly into CudnnFusedConvRewriter. 
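This commit removes the compile-time #if CUDA_VERSION / CUDNN_VERSION guard around the FP8 graph-convolution rewrite and instead threads the DNN-library and toolkit versions into the pass, so the gate becomes a runtime check (cuDNN >= 8.9 and CUDA toolkit >= 12.0, in addition to the existing Hopper capability check, as the diff below shows). A minimal stand-alone sketch of that guard; the struct and helper names here are hypothetical, only the thresholds come from the change itself:

  #include <cstdint>
  #include <tuple>

  // Hypothetical stand-in for the se::dnn::VersionInfo ordering used by the pass.
  struct DnnVersion { int major_ver = 0, minor_ver = 0, patch = 0; };

  inline bool AtLeast(const DnnVersion& v, int major_ver, int minor_ver, int patch) {
    return std::tie(v.major_ver, v.minor_ver, v.patch) >=
           std::tie(major_ver, minor_ver, patch);
  }

  // The FP8 graph-conv rewrite is skipped unless the runtime-reported versions
  // meet the same thresholds the old preprocessor guard encoded.
  inline bool F8GraphConvVersionOk(const DnnVersion& dnn_version, int32_t toolkit_version) {
    return AtLeast(dnn_version, 8, 9, 0) && toolkit_version >= 12000;
  }

Checking at runtime means one build of the pass serves all toolkits: configurations below the thresholds simply report no change instead of having the rewrite compiled out.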
PiperOrigin-RevId: 640256343 --- xla/service/gpu/BUILD | 8 +- xla/service/gpu/amdgpu_compiler.cc | 3 +- xla/service/gpu/cudnn_fused_conv_rewriter.cc | 34 +++-- xla/service/gpu/cudnn_fused_conv_rewriter.h | 21 ++- .../gpu/cudnn_fused_conv_rewriter_test.cc | 138 ++++++++++++------ xla/service/gpu/nvptx_compiler.cc | 3 +- 6 files changed, 136 insertions(+), 71 deletions(-) diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index c99e3b963100c..c6a23d2287e8f 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -4599,7 +4599,6 @@ cc_library( name = "cudnn_fused_conv_rewriter", srcs = ["cudnn_fused_conv_rewriter.cc"], hdrs = ["cudnn_fused_conv_rewriter.h"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ ":backend_configs_cc", ":cublas_cudnn", @@ -4628,10 +4627,7 @@ cc_library( "@tsl//tsl/platform:errors", "@tsl//tsl/platform:ml_dtypes", "@tsl//tsl/platform:statusor", - ] + if_cuda_is_configured([ - "@local_config_cuda//cuda:cuda_headers", - "@local_config_cuda//cuda:cudnn_header", - ]), + ], ) xla_test( @@ -4654,6 +4650,7 @@ xla_test( ":cublas_cudnn", ":cudnn_fused_conv_rewriter", ":gpu_conv_rewriter", + ":stream_executor_util", "//xla:comparison_util", "//xla:error_spec", "//xla/hlo/ir:hlo", @@ -4668,6 +4665,7 @@ xla_test( "//xla/service:reshape_mover", "//xla/service/gpu/tests:gpu_codegen_test", "//xla/stream_executor:device_description", + "//xla/stream_executor:stream_executor_headers", "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", diff --git a/xla/service/gpu/amdgpu_compiler.cc b/xla/service/gpu/amdgpu_compiler.cc index e47054f633f03..c20a21025db44 100644 --- a/xla/service/gpu/amdgpu_compiler.cc +++ b/xla/service/gpu/amdgpu_compiler.cc @@ -111,7 +111,8 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization( pipeline.AddPass(); pipeline.AddPass(); auto rcc = std::get(gpu_version); - pipeline.AddPass(rcc); + pipeline.AddPass(rcc, dnn_version, + GetToolkitVersion()); // The conv padding/vectorization passes which we need to get rid of. They // also leave behind unnecessary tuple/get-tuple-element pairs that diff --git a/xla/service/gpu/cudnn_fused_conv_rewriter.cc b/xla/service/gpu/cudnn_fused_conv_rewriter.cc index 8efa13b9ee22e..475fe8e2fb865 100644 --- a/xla/service/gpu/cudnn_fused_conv_rewriter.cc +++ b/xla/service/gpu/cudnn_fused_conv_rewriter.cc @@ -39,28 +39,23 @@ limitations under the License. 
#include "absl/strings/string_view.h" #include "xla/comparison_util.h" #include "xla/debug_options_flags.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/literal.h" -#include "xla/shape.h" -#include "xla/shape_util.h" -#include "xla/stream_executor/device_description.h" -#include "xla/util.h" -#include "tsl/platform/ml_dtypes.h" - -#if GOOGLE_CUDA -#include "third_party/gpus/cuda/include/cuda.h" -#include "third_party/gpus/cudnn/cudnn.h" -#endif - -#include "xla/hlo/ir/hlo_instruction.h" #include "xla/primitive_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/hlo_creation_utils.h" #include "xla/service/pattern_matcher.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/dnn.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" +#include "tsl/platform/ml_dtypes.h" #include "tsl/platform/statusor.h" namespace xla { @@ -666,10 +661,17 @@ CaptureConvGraph(HloInstruction* instr, HloInstruction* convolution, // 5. Optionally calculate the maximum of the absolute of the result. // 6. Optionally cast the output back to FP8. absl::StatusOr F8GraphConv(HloComputation* comp, - se::CudaComputeCapability cc) { + se::CudaComputeCapability cc, + se::dnn::VersionInfo dnn_version, + int32_t toolkit_version) { bool changed = false; -#if CUDA_VERSION >= 12000 && CUDNN_VERSION >= 8900 + if (dnn_version < se::dnn::VersionInfo(8, 9, 0)) { + return false; + } + if (toolkit_version < 12000) { + return false; + } if (!cc.IsAtLeast(se::CudaComputeCapability::HOPPER)) { return false; } @@ -767,7 +769,6 @@ absl::StatusOr F8GraphConv(HloComputation* comp, changed = true; } } -#endif // CUDA_VERSION >= 12000 && CUDNN_VERSION >= 8900 return changed; } @@ -1493,7 +1494,8 @@ absl::StatusOr CudnnFusedConvRewriter::Run( // ForwardGraph Custom Call. if(!IsROCm(compute_capability_)) { auto cc = std::get(compute_capability_); - TF_ASSIGN_OR_RETURN(changed, F8GraphConv(comp, cc)); + TF_ASSIGN_OR_RETURN( + changed, F8GraphConv(comp, cc, dnn_version_, toolkit_version_)); if (changed) { return changed; } diff --git a/xla/service/gpu/cudnn_fused_conv_rewriter.h b/xla/service/gpu/cudnn_fused_conv_rewriter.h index ff1d156525539..801d7b29d384f 100644 --- a/xla/service/gpu/cudnn_fused_conv_rewriter.h +++ b/xla/service/gpu/cudnn_fused_conv_rewriter.h @@ -16,12 +16,15 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_CUDNN_FUSED_CONV_REWRITER_H_ #define XLA_SERVICE_GPU_CUDNN_FUSED_CONV_REWRITER_H_ +#include + #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" #include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/dnn.h" namespace xla { namespace gpu { @@ -98,10 +101,18 @@ namespace gpu { // pass returns an error -- cudnn will not be able to run it. 
class CudnnFusedConvRewriter : public HloModulePass { public: - explicit CudnnFusedConvRewriter(se::CudaComputeCapability cc) - : compute_capability_(cc) {} - explicit CudnnFusedConvRewriter(se::RocmComputeCapability cc) - : compute_capability_(cc) {} + CudnnFusedConvRewriter(se::CudaComputeCapability cc, + se::dnn::VersionInfo dnn_version, + int32_t toolkit_version) + : compute_capability_(cc), + dnn_version_(dnn_version), + toolkit_version_(toolkit_version) {} + CudnnFusedConvRewriter(se::RocmComputeCapability cc, + se::dnn::VersionInfo dnn_version, + int32_t toolkit_version) + : compute_capability_(cc), + dnn_version_(dnn_version), + toolkit_version_(toolkit_version) {} absl::string_view name() const override { return "cudnn-fused-convolution-rewriter"; @@ -114,6 +125,8 @@ class CudnnFusedConvRewriter : public HloModulePass { private: const se::GpuComputeCapability compute_capability_; + const se::dnn::VersionInfo dnn_version_; + const int32_t toolkit_version_; }; } // namespace gpu diff --git a/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc b/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc index 2add39d4a3b1b..12dcc9fd5aaf8 100644 --- a/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc +++ b/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc @@ -33,8 +33,10 @@ limitations under the License. #include "xla/error_spec.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/gpu/stream_executor_util.h" #include "xla/service/hlo_module_config.h" #include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/dnn.h" #include "xla/tests/verified_hlo_module.h" #include "tsl/platform/statusor.h" @@ -85,7 +87,9 @@ class CudnnFusedConvRewriterHloTest : public HloTestBase { ->GetDeviceDescription() .cuda_compute_capability(); } - + stream_executor::dnn::VersionInfo GetDnnVersion() { + return GetDnnVersionInfoOrDefault(backend().default_stream_executor()); + } CudnnFusedConvRewriterHloTest() : HloTestBase(/*verifier_layout_sensitive=*/false, /*allow_mixed_precision_in_hlo_verifier=*/false, @@ -105,6 +109,9 @@ class CudnnFusedConvRewriterTest : public GpuCodegenTest { ->GetDeviceDescription() .cuda_compute_capability(); } + stream_executor::dnn::VersionInfo GetDnnVersion() { + return GetDnnVersionInfoOrDefault(backend().default_stream_executor()); + } protected: std::string GetOptimizedHlo(absl::string_view hlo_string) { @@ -225,12 +232,14 @@ class CudnnFusedConvRewriterTest : public GpuCodegenTest { RunAndFilecheckHloRewrite( module->ToString(HloPrintOptions{}.set_print_operand_shape(false)), CudnnFusedConvRewriter( - se::CudaComputeCapability{se::CudaComputeCapability::HOPPER, 0}), + se::CudaComputeCapability{se::CudaComputeCapability::HOPPER, 0}, + GetDnnVersion(), CUDA_VERSION), custom_call_string); RunAndFilecheckHloRewrite( module->ToString(HloPrintOptions{}.set_print_operand_shape(false)), CudnnFusedConvRewriter( - se::CudaComputeCapability{se::CudaComputeCapability::HOPPER, 0}), + se::CudaComputeCapability{se::CudaComputeCapability::HOPPER, 0}, + GetDnnVersion(), CUDA_VERSION), serialized_graph_string); } } @@ -1284,7 +1293,8 @@ TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToFloat) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability()}; + CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), + CUDA_VERSION}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -1317,7 +1327,8 @@ 
TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToInt8BiasSideInput) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability()}; + CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), + CUDA_VERSION}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); // Simplify new `convert`'s that may be added to the graph. @@ -1357,7 +1368,8 @@ TEST_F(CudnnFusedConvRewriterHloTest, TestReluAfterConvert) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability()}; + CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), + CUDA_VERSION}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); // Simplify new `convert`'s that may be added to the graph. @@ -1410,7 +1422,8 @@ TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToFloatBiasSideInput) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability()}; + CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), + CUDA_VERSION}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); // Simplify new `convert`'s that may be added to the graph. @@ -1455,7 +1468,8 @@ TEST_F(CudnnFusedConvRewriterHloTest, Int8SideInputWithScaleAndReshape) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability()}; + CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), + CUDA_VERSION}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); // Simplify new `convert`'s that may be added to the graph. @@ -1509,7 +1523,8 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseAlpha) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability()}; + CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), + CUDA_VERSION}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -1548,7 +1563,8 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseRelu) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability()}; + CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), + CUDA_VERSION}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -1588,7 +1604,8 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseReluIfMultipleUses) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability()}; + CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), + CUDA_VERSION}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -1638,7 +1655,8 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseElu) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); // elu fusion is only active on Ampere+. 
- CudnnFusedConvRewriter fuser{se::CudaComputeCapability(8, 0)}; + CudnnFusedConvRewriter fuser{se::CudaComputeCapability(8, 0), GetDnnVersion(), + CUDA_VERSION}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -1685,7 +1703,8 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseEluIfMultipleUses) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability()}; + CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), + CUDA_VERSION}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -1738,7 +1757,8 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseRelu6) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); // relu6 fusion is only enabled on Ampere+. - CudnnFusedConvRewriter fuser{se::CudaComputeCapability(8, 0)}; + CudnnFusedConvRewriter fuser{se::CudaComputeCapability(8, 0), GetDnnVersion(), + CUDA_VERSION}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); const HloInstruction* conv; @@ -1780,7 +1800,8 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseRelu6IfMultipleUses) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability()}; + CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), + CUDA_VERSION}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -1828,7 +1849,8 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseLeakyRelu) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); // Leaky-relu fusion is only enabled on Ampere+. - CudnnFusedConvRewriter fuser{se::CudaComputeCapability(8, 0)}; + CudnnFusedConvRewriter fuser{se::CudaComputeCapability(8, 0), GetDnnVersion(), + CUDA_VERSION}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -1873,7 +1895,8 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseLeakyReluIfMultipleUses) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability()}; + CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), + CUDA_VERSION}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -1920,7 +1943,8 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseAlphaIfMultipleUsers) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability()}; + CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), + CUDA_VERSION}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -1959,7 +1983,8 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseBiasIfMultipleUsers) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability()}; + CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), + CUDA_VERSION}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -1997,7 +2022,8 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseSideInputThroughRelu) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability()}; + CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), + CUDA_VERSION}; 
TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -2035,7 +2061,8 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseBiasThroughRelu) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability()}; + CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), + CUDA_VERSION}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -2070,7 +2097,8 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseSideInputIfMultipleUsers) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability()}; + CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), + CUDA_VERSION}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -2106,7 +2134,8 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseConvertToF16IfMultipleUsers) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability()}; + CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), + CUDA_VERSION}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -2139,7 +2168,8 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseToS8IfMultipleUsers) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability()}; + CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), + CUDA_VERSION}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -2171,7 +2201,8 @@ TEST_F(CudnnFusedConvRewriterHloTest, RemoveConvertByFusingS32ToF32) { TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability()}; + CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), + CUDA_VERSION}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); HloInstruction* conv1 = nullptr; @@ -2197,7 +2228,8 @@ TEST_F(CudnnFusedConvRewriterHloTest, RemoveConvertByFusingS8ToF32) { TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability()}; + CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), + CUDA_VERSION}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); HloInstruction* conv1 = nullptr; @@ -2223,7 +2255,8 @@ TEST_F(CudnnFusedConvRewriterHloTest, RemoveConvertByFusingF32ToS8) { TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability()}; + CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), + CUDA_VERSION}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); HloInstruction* conv1 = nullptr; @@ -2250,7 +2283,8 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontRemoveConvertDuetoMultpleUser) { TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter 
fuser{GetCudaComputeCapability()}; + CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), + CUDA_VERSION}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); HloInstruction* conv1 = nullptr; @@ -2279,7 +2313,8 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseBias) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability()}; + CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), + CUDA_VERSION}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -2310,7 +2345,8 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseSideInput) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability()}; + CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), + CUDA_VERSION}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -2352,7 +2388,8 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseScaledSideInput) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability()}; + CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), + CUDA_VERSION}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -2394,7 +2431,8 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseBiasAndSideInput) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability()}; + CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), + CUDA_VERSION}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -2431,7 +2469,8 @@ TEST_F(CudnnFusedConvRewriterHloTest, EffectiveScalarBias) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability()}; + CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), + CUDA_VERSION}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -2473,7 +2512,8 @@ TEST_F(CudnnFusedConvRewriterHloTest, StrengthReduceF32ToF16) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability()}; + CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), + CUDA_VERSION}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); // Simplify new `convert`'s that may be added to the graph. @@ -2519,7 +2559,8 @@ TEST_F(CudnnFusedConvRewriterHloTest, BroadcastReshapeTransposeAfterConvert) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability()}; + CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), + CUDA_VERSION}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); // Simplify new `convert`'s that may be added to the graph. @@ -2571,7 +2612,8 @@ TEST_F(CudnnFusedConvRewriterHloTest, NoStrengthReduceF32ToF16IfBiasIsF32) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability()}; + CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), + CUDA_VERSION}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); // Simplify new `convert`'s that may be added to the graph. 
@@ -2626,7 +2668,8 @@ TEST_F(CudnnFusedConvRewriterHloTest, F32Constants) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability()}; + CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), + CUDA_VERSION}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); // Simplify new `convert`'s that may be added to the graph, and fold @@ -2679,7 +2722,8 @@ TEST_F(CudnnFusedConvRewriterHloTest, F32ConstantsNotLosslesslyConvertible) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability()}; + CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), + CUDA_VERSION}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); // Simplify new `convert`'s that may be added to the graph, and fold @@ -2742,7 +2786,8 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseReluBeforeConvert) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability()}; + CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), + CUDA_VERSION}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); // Simplify new `convert`'s that may be added to the graph. @@ -2784,7 +2829,8 @@ TEST_F(CudnnFusedConvRewriterHloTest, BiasTypeMatchesConvTypeIfFp) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); - CudnnFusedConvRewriter fuser{GetCudaComputeCapability()}; + CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), + CUDA_VERSION}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); // Simplify new `convert`'s that may be added to the graph. @@ -3059,8 +3105,10 @@ TEST_F(CudnnFusedConvRewriterTest, TestConvInt8ToInt8NoClamp) { })"); TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - ASSERT_FALSE( - CudnnFusedConvRewriter(GetCudaComputeCapability()).Run(m.get()).ok()); + ASSERT_FALSE(CudnnFusedConvRewriter(GetCudaComputeCapability(), + GetDnnVersion(), CUDA_VERSION) + .Run(m.get()) + .ok()); } TEST_F(CudnnFusedConvRewriterTest, TestFusedConvInt8ToInt8NoClamp) { @@ -3085,8 +3133,10 @@ TEST_F(CudnnFusedConvRewriterTest, TestFusedConvInt8ToInt8NoClamp) { })"); TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - ASSERT_FALSE( - CudnnFusedConvRewriter(GetCudaComputeCapability()).Run(m.get()).ok()); + ASSERT_FALSE(CudnnFusedConvRewriter(GetCudaComputeCapability(), + GetDnnVersion(), CUDA_VERSION) + .Run(m.get()) + .ok()); } } // namespace diff --git a/xla/service/gpu/nvptx_compiler.cc b/xla/service/gpu/nvptx_compiler.cc index f1224ff3f3809..5fa42c0e26fa1 100644 --- a/xla/service/gpu/nvptx_compiler.cc +++ b/xla/service/gpu/nvptx_compiler.cc @@ -198,7 +198,8 @@ absl::Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization( pipeline.AddPass(); pipeline.AddPass(); - pipeline.AddPass(cuda_compute_capability); + pipeline.AddPass(cuda_compute_capability, dnn_version, + GetToolkitVersion()); pipeline.AddPass(); pipeline.AddPass(cuda_compute_capability); pipeline.AddPass(cuda_compute_capability, From 087578f2610a0411f0a074ae8db442172d94599f Mon Sep 17 00:00:00 2001 From: Harsha HS Date: Wed, 12 Jun 2024 06:16:26 -0700 Subject: [PATCH 7/7] [ROCm] Fix build break of cudnn_fused_conv_rewriter_test due to 1268712 --- xla/service/gpu/BUILD | 5 +- xla/service/gpu/amdgpu_compiler.cc | 2 +- .../gpu/cudnn_fused_conv_rewriter_test.cc | 110 +++++++++++------- 3 files changed, 71 
insertions(+), 46 deletions(-) diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index c6a23d2287e8f..49969912ed7c0 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -4643,7 +4643,8 @@ xla_test( backends = [ "gpu_a100", ] + if_oss(["gpu"]), - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + + if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]), shard_count = 10, deps = [ ":backend_configs_cc", @@ -4680,6 +4681,8 @@ xla_test( ] + if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", "@local_config_cuda//cuda:cudnn_header", + ]) + if_rocm_is_configured([ + "@local_config_rocm//rocm:rocm_headers" ]), ) diff --git a/xla/service/gpu/amdgpu_compiler.cc b/xla/service/gpu/amdgpu_compiler.cc index c20a21025db44..d6f46ac3e372b 100644 --- a/xla/service/gpu/amdgpu_compiler.cc +++ b/xla/service/gpu/amdgpu_compiler.cc @@ -112,7 +112,7 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization( pipeline.AddPass(); auto rcc = std::get(gpu_version); pipeline.AddPass(rcc, dnn_version, - GetToolkitVersion()); + 0); // The conv padding/vectorization passes which we need to get rid of. They // also leave behind unnecessary tuple/get-tuple-element pairs that diff --git a/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc b/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc index 12dcc9fd5aaf8..9c6c78385f174 100644 --- a/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc +++ b/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc @@ -42,7 +42,9 @@ limitations under the License. #if GOOGLE_CUDA #include "third_party/gpus/cuda/include/cuda.h" -#endif +#elif TENSORFLOW_USE_ROCM +#include "rocm/rocm_config.h" +#endif //GOOGLE_CUDA #include "xla/service/algebraic_simplifier.h" #include "xla/service/convert_mover.h" @@ -87,9 +89,20 @@ class CudnnFusedConvRewriterHloTest : public HloTestBase { ->GetDeviceDescription() .cuda_compute_capability(); } + stream_executor::dnn::VersionInfo GetDnnVersion() { - return GetDnnVersionInfoOrDefault(backend().default_stream_executor()); + return stream_executor::dnn::VersionInfo(8, 9, 4); + } + + int32_t GetToolkitVersion() const { +#if GOOGLE_CUDA + return CUDA_VERSION; +#elif TENSORFLOW_USE_ROCM + return TF_ROCM_VERSION; +#endif + return 0; } + CudnnFusedConvRewriterHloTest() : HloTestBase(/*verifier_layout_sensitive=*/false, /*allow_mixed_precision_in_hlo_verifier=*/false, @@ -110,7 +123,16 @@ class CudnnFusedConvRewriterTest : public GpuCodegenTest { .cuda_compute_capability(); } stream_executor::dnn::VersionInfo GetDnnVersion() { - return GetDnnVersionInfoOrDefault(backend().default_stream_executor()); + return stream_executor::dnn::VersionInfo(8, 9, 4); + } + + int32_t GetToolkitVersion() const { +#if GOOGLE_CUDA + return CUDA_VERSION; +#elif TENSORFLOW_USE_ROCM + return TF_ROCM_VERSION; +#endif + return 0; } protected: std::string GetOptimizedHlo(absl::string_view hlo_string) { @@ -233,13 +255,13 @@ class CudnnFusedConvRewriterTest : public GpuCodegenTest { module->ToString(HloPrintOptions{}.set_print_operand_shape(false)), CudnnFusedConvRewriter( se::CudaComputeCapability{se::CudaComputeCapability::HOPPER, 0}, - GetDnnVersion(), CUDA_VERSION), + GetDnnVersion(), GetToolkitVersion()), custom_call_string); RunAndFilecheckHloRewrite( module->ToString(HloPrintOptions{}.set_print_operand_shape(false)), CudnnFusedConvRewriter( se::CudaComputeCapability{se::CudaComputeCapability::HOPPER, 0}, - GetDnnVersion(), CUDA_VERSION), + GetDnnVersion(), GetToolkitVersion()), serialized_graph_string); } } @@ -1294,7 +1316,7 @@
TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToFloat) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - CUDA_VERSION}; + GetToolkitVersion()}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -1328,7 +1350,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToInt8BiasSideInput) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - CUDA_VERSION}; + GetToolkitVersion()}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); // Simplify new `convert`'s that may be added to the graph. @@ -1369,7 +1391,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, TestReluAfterConvert) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - CUDA_VERSION}; + GetToolkitVersion()}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); // Simplify new `convert`'s that may be added to the graph. @@ -1423,7 +1445,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToFloatBiasSideInput) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - CUDA_VERSION}; + GetToolkitVersion()}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); // Simplify new `convert`'s that may be added to the graph. @@ -1469,7 +1491,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, Int8SideInputWithScaleAndReshape) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - CUDA_VERSION}; + GetToolkitVersion()}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); // Simplify new `convert`'s that may be added to the graph. @@ -1524,7 +1546,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseAlpha) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - CUDA_VERSION}; + GetToolkitVersion()}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -1564,7 +1586,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseRelu) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - CUDA_VERSION}; + GetToolkitVersion()}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -1605,7 +1627,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseReluIfMultipleUses) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - CUDA_VERSION}; + GetToolkitVersion()}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -1656,7 +1678,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseElu) { TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); // elu fusion is only active on Ampere+. 
CudnnFusedConvRewriter fuser{se::CudaComputeCapability(8, 0), GetDnnVersion(), - CUDA_VERSION}; + GetToolkitVersion()}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -1704,7 +1726,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseEluIfMultipleUses) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - CUDA_VERSION}; + GetToolkitVersion()}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -1758,7 +1780,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseRelu6) { TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); // relu6 fusion is only enabled on Ampere+. CudnnFusedConvRewriter fuser{se::CudaComputeCapability(8, 0), GetDnnVersion(), - CUDA_VERSION}; + GetToolkitVersion()}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); const HloInstruction* conv; @@ -1801,7 +1823,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseRelu6IfMultipleUses) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - CUDA_VERSION}; + GetToolkitVersion()}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -1850,7 +1872,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseLeakyRelu) { TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); // Leaky-relu fusion is only enabled on Ampere+. CudnnFusedConvRewriter fuser{se::CudaComputeCapability(8, 0), GetDnnVersion(), - CUDA_VERSION}; + GetToolkitVersion()}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -1896,7 +1918,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseLeakyReluIfMultipleUses) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - CUDA_VERSION}; + GetToolkitVersion()}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -1944,7 +1966,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseAlphaIfMultipleUsers) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - CUDA_VERSION}; + GetToolkitVersion()}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -1984,7 +2006,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseBiasIfMultipleUsers) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - CUDA_VERSION}; + GetToolkitVersion()}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -2023,7 +2045,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseSideInputThroughRelu) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - CUDA_VERSION}; + GetToolkitVersion()}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -2062,7 +2084,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseBiasThroughRelu) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - CUDA_VERSION}; + GetToolkitVersion()}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); 
SCOPED_TRACE(m->ToString()); @@ -2098,7 +2120,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseSideInputIfMultipleUsers) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - CUDA_VERSION}; + GetToolkitVersion()}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -2135,7 +2157,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseConvertToF16IfMultipleUsers) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - CUDA_VERSION}; + GetToolkitVersion()}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -2169,7 +2191,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseToS8IfMultipleUsers) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - CUDA_VERSION}; + GetToolkitVersion()}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -2202,7 +2224,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, RemoveConvertByFusingS32ToF32) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - CUDA_VERSION}; + GetToolkitVersion()}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); HloInstruction* conv1 = nullptr; @@ -2229,7 +2251,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, RemoveConvertByFusingS8ToF32) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - CUDA_VERSION}; + GetToolkitVersion()}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); HloInstruction* conv1 = nullptr; @@ -2256,7 +2278,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, RemoveConvertByFusingF32ToS8) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - CUDA_VERSION}; + GetToolkitVersion()}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); HloInstruction* conv1 = nullptr; @@ -2284,7 +2306,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontRemoveConvertDuetoMultpleUser) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - CUDA_VERSION}; + GetToolkitVersion()}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); HloInstruction* conv1 = nullptr; @@ -2314,7 +2336,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseBias) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - CUDA_VERSION}; + GetToolkitVersion()}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -2346,7 +2368,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseSideInput) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - CUDA_VERSION}; + GetToolkitVersion()}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -2389,7 +2411,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseScaledSideInput) { 
GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - CUDA_VERSION}; + GetToolkitVersion()}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -2432,7 +2454,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseBiasAndSideInput) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - CUDA_VERSION}; + GetToolkitVersion()}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -2470,7 +2492,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, EffectiveScalarBias) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - CUDA_VERSION}; + GetToolkitVersion()}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); SCOPED_TRACE(m->ToString()); @@ -2513,7 +2535,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, StrengthReduceF32ToF16) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - CUDA_VERSION}; + GetToolkitVersion()}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); // Simplify new `convert`'s that may be added to the graph. @@ -2560,7 +2582,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, BroadcastReshapeTransposeAfterConvert) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - CUDA_VERSION}; + GetToolkitVersion()}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); // Simplify new `convert`'s that may be added to the graph. @@ -2613,7 +2635,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, NoStrengthReduceF32ToF16IfBiasIsF32) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - CUDA_VERSION}; + GetToolkitVersion()}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); // Simplify new `convert`'s that may be added to the graph. @@ -2669,7 +2691,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, F32Constants) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - CUDA_VERSION}; + GetToolkitVersion()}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); // Simplify new `convert`'s that may be added to the graph, and fold @@ -2723,7 +2745,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, F32ConstantsNotLosslesslyConvertible) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - CUDA_VERSION}; + GetToolkitVersion()}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); // Simplify new `convert`'s that may be added to the graph, and fold @@ -2787,7 +2809,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseReluBeforeConvert) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - CUDA_VERSION}; + GetToolkitVersion()}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); // Simplify new `convert`'s that may be added to the graph. 
@@ -2830,7 +2852,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, BiasTypeMatchesConvTypeIfFp) { GpuConvRewriter rewriter; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), - CUDA_VERSION}; + GetToolkitVersion()}; TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status()); // Simplify new `convert`'s that may be added to the graph. @@ -3106,7 +3128,7 @@ TEST_F(CudnnFusedConvRewriterTest, TestConvInt8ToInt8NoClamp) { TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); ASSERT_FALSE(CudnnFusedConvRewriter(GetCudaComputeCapability(), - GetDnnVersion(), CUDA_VERSION) + GetDnnVersion(), GetToolkitVersion()) .Run(m.get()) .ok()); } @@ -3134,7 +3156,7 @@ TEST_F(CudnnFusedConvRewriterTest, TestFusedConvInt8ToInt8NoClamp) { TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); ASSERT_FALSE(CudnnFusedConvRewriter(GetCudaComputeCapability(), - GetDnnVersion(), CUDA_VERSION) + GetDnnVersion(), GetToolkitVersion()) .Run(m.get()) .ok()); }
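
Taken together, patch 6 and this fix replace the old compile-time CUDA_VERSION/CUDNN_VERSION gate with a runtime check on the values now threaded through the rewriter's constructor. Below is a condensed sketch of that gating logic, lifted from the F8GraphConv hunk in patch 6; the helper name ShouldTryF8GraphConv is hypothetical and used here only for illustration, since the real code performs these checks inline before attempting the FP8 graph rewrite.

  // Hypothetical standalone helper, condensed from the inline checks in
  // F8GraphConv; it only decides whether the FP8 conv-graph rewrite may run.
  bool ShouldTryF8GraphConv(se::CudaComputeCapability cc,
                            se::dnn::VersionInfo dnn_version,
                            int32_t toolkit_version) {
    if (dnn_version < se::dnn::VersionInfo(8, 9, 0)) return false;  // needs cuDNN >= 8.9
    if (toolkit_version < 12000) return false;                      // needs CUDA >= 12.0
    if (!cc.IsAtLeast(se::CudaComputeCapability::HOPPER)) return false;
    return true;
  }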