diff --git a/build_tools/rocm/run_xla.sh b/build_tools/rocm/run_xla.sh
index bc2add44570ab..7fd68b6e35431 100755
--- a/build_tools/rocm/run_xla.sh
+++ b/build_tools/rocm/run_xla.sh
@@ -60,6 +60,7 @@ bazel \
     --test_sharding_strategy=disabled \
     --test_output=errors \
     --keep_going \
+    --flaky_test_attempts=3 \
     --local_test_jobs=${N_TEST_JOBS} \
     --test_env=TF_TESTS_PER_GPU=$TF_TESTS_PER_GPU \
     --test_env=TF_GPU_COUNT=$TF_GPU_COUNT \
diff --git a/xla/service/gpu/tests/gemm_rewrite_test.cc b/xla/service/gpu/tests/gemm_rewrite_test.cc
index c3b43726f0e1d..a6e8e62d51131 100644
--- a/xla/service/gpu/tests/gemm_rewrite_test.cc
+++ b/xla/service/gpu/tests/gemm_rewrite_test.cc
@@ -222,8 +222,7 @@ ENTRY bf16gemm {
 }
 )";
 
-  if (!IsCuda() ||
-      HasCudaComputeCapability(se::CudaComputeCapability::Hopper())) {
+  if (HasCudaComputeCapability(se::CudaComputeCapability::Hopper())) {
     // The Hopper optimized HLO has a BF16 multiply instruction since Hopper has
     // native BF16 multiply support.
     MatchOptimizedHlo(hlo_text, R"(
@@ -789,6 +788,41 @@ ENTRY AddDotsFunc {
 )");
 }
 
+TEST_P(ParameterizedGemmRewriteTest, F64C64_CublasLtSupportTest) {
+  // This test should fail if gemm rewriter does not correctly rewrite
+  // F64/C64 dots to cublas-lt or legacy cublas calls
+  {
+    const char* hlo_text = R"(
+HloModule F64_rewrite
+
+ENTRY AddDotsFunc {
+  x = f64[2,2] parameter(0)
+  y = f64[2,2] parameter(1)
+  k = f64[] constant(3.0)
+  k_broadcast = f64[2, 2] broadcast(k), dimensions={}
+  dot_a = f64[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  ROOT dot_a_multiplied = f64[2, 2] multiply(dot_a, k_broadcast)
+}
+)";
+    EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-4, 1e-5}));
+  }
+  {
+    const char* hlo_text = R"(
+HloModule C64_rewrite
+
+ENTRY AddDotsFunc {
+  x = c64[2,2] parameter(0)
+  y = c64[2,2] parameter(1)
+  k = c64[] constant((3.0, 3.0))
+  k_broadcast = c64[2, 2] broadcast(k), dimensions={}
+  dot_a = c64[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  ROOT dot_a_multiplied = c64[2, 2] multiply(dot_a, k_broadcast)
+}
+)";
+    EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-4, 1e-5}));
+  }
+}
+
 TEST_P(ParameterizedGemmRewriteTest, ComplexAlphaSimpleRewrite) {
   if (!IsCuda() && GetDebugOptionsForTest().xla_gpu_enable_cublaslt()) {
     GTEST_SKIP() << "TODO: Unsupported C64 gpublas-lt datatype on ROCM";
@@ -7916,6 +7950,7 @@ class GemmRewriteAllocationTest : public GpuCodegenTest {
   DebugOptions GetDebugOptionsForTest() override {
     DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
     // Make sure the rewriter does not skip the rewrite for being too small.
+    debug_options.set_xla_gpu_enable_triton_gemm(false);
     debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0);
     return debug_options;
   }
diff --git a/xla/tests/matmul_test.cc b/xla/tests/matmul_test.cc
index 61e97d15a3882..668fa32425391 100644
--- a/xla/tests/matmul_test.cc
+++ b/xla/tests/matmul_test.cc
@@ -51,6 +51,22 @@ class MatmulTestWithCublas : public HloTestBase,
   const bool use_cublas_lt_{GetParam()};
 };
 
+TEST_P(MatmulTestWithCublas, GemmRewriter_RegressionTestF64) {
+  const char* module_str = R"(
+HloModule GeneralMatMulActivation.7, entry_computation_layout={(f64[2,2,2]{2,1,0}, f64[2,2,2]{2,1,0})->f64[2,2,2]{2,1,0}}
+
+ENTRY GeneralMatMulActivation.7 {
+  x.1 = f64[2,2,2]{2,1,0} parameter(0)
+  y.2 = f64[2,2,2]{2,1,0} parameter(1)
+  dot.3 = f64[2,2,2]{2,1,0} dot(x.1, y.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}
+  constant.4 = f64[] constant(0)
+  broadcast.5 = f64[2,2,2]{2,1,0} broadcast(constant.4), dimensions={}
+  ROOT maximum.6 = f64[2,2,2]{2,1,0} maximum(dot.3, broadcast.5)
+})";
+
+  EXPECT_TRUE(RunAndCompare(module_str, ErrorSpec{1e-4, 1e-4}));
+}
+
 // There was an issue where the compilation process of an Inverse operation was
 // resulting in a cached cuBLASLt matmul plan which was incorrectly fetched at
 // the time of the Matmul operation