Merge pull request #23 from ROCm/rocm-jaxlib-v0.4.28-qa_fix_gemm_rewrite_test

Fix gemm_rewrite_test and add flaky_test_attempts option
i-chaochen authored Jun 27, 2024
2 parents 6a1f142 + adf7949 commit 5f7bb79
Showing 3 changed files with 54 additions and 2 deletions.
1 change: 1 addition & 0 deletions build_tools/rocm/run_xla.sh
@@ -60,6 +60,7 @@ bazel \
--test_sharding_strategy=disabled \
--test_output=errors \
--keep_going \
--flaky_test_attempts=3 \
--local_test_jobs=${N_TEST_JOBS} \
--test_env=TF_TESTS_PER_GPU=$TF_TESTS_PER_GPU \
--test_env=TF_GPU_COUNT=$TF_GPU_COUNT \
39 changes: 37 additions & 2 deletions xla/service/gpu/tests/gemm_rewrite_test.cc
@@ -222,8 +222,7 @@ ENTRY bf16gemm {
}
)";

- if (!IsCuda() ||
-     HasCudaComputeCapability(se::CudaComputeCapability::Hopper())) {
+ if (HasCudaComputeCapability(se::CudaComputeCapability::Hopper())) {
// The Hopper optimized HLO has a BF16 multiply instruction since Hopper has
// native BF16 multiply support.
MatchOptimizedHlo(hlo_text, R"(
@@ -789,6 +788,41 @@ ENTRY AddDotsFunc {
)");
}

TEST_P(ParameterizedGemmRewriteTest, F64C64_CublasLtSupportTest) {
// This test should fail if the gemm rewriter does not correctly rewrite
// F64/C64 dots to cublas-lt or legacy cublas calls.
{
const char* hlo_text = R"(
HloModule F64_rewrite
ENTRY AddDotsFunc {
x = f64[2,2] parameter(0)
y = f64[2,2] parameter(1)
k = f64[] constant(3.0)
k_broadcast = f64[2, 2] broadcast(k), dimensions={}
dot_a = f64[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
ROOT dot_a_multiplied = f64[2, 2] multiply(dot_a, k_broadcast)
}
)";
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-4, 1e-5}));
}
{
const char* hlo_text = R"(
HloModule C64_rewrite
ENTRY AddDotsFunc {
x = c64[2,2] parameter(0)
y = c64[2,2] parameter(1)
k = c64[] constant((3.0, 3.0))
k_broadcast = c64[2, 2] broadcast(k), dimensions={}
dot_a = c64[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
ROOT dot_a_multiplied = c64[2, 2] multiply(dot_a, k_broadcast)
}
)";
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-4, 1e-5}));
}
}

TEST_P(ParameterizedGemmRewriteTest, ComplexAlphaSimpleRewrite) {
if (!IsCuda() && GetDebugOptionsForTest().xla_gpu_enable_cublaslt()) {
GTEST_SKIP() << "TODO: Unsupported C64 gpublas-lt datatype on ROCM";
@@ -7916,6 +7950,7 @@ class GemmRewriteAllocationTest : public GpuCodegenTest {
DebugOptions GetDebugOptionsForTest() override {
DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
// Make sure the rewriter does not skip the rewrite for being too small.
debug_options.set_xla_gpu_enable_triton_gemm(false);
debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0);
return debug_options;
}
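The new F64C64_CublasLtSupportTest above validates numerics only, through RunAndCompare. For context, below is a minimal sketch of how the rewrite itself could additionally be pinned down with the fixture's MatchOptimizedHlo helper; the test name, the expected custom-call target strings ("__cublas$gemm" for legacy cuBLAS, "__cublas$lt$matmul" for cublas-lt), and the CHECK pattern are assumptions for illustration, not part of this commit.

// Sketch only (not in this commit): assumes a small F64 dot is rewritten into
// a cuBLAS custom call whose target is either the legacy or the cublas-lt
// entry point.
TEST_P(ParameterizedGemmRewriteTest, F64DotRewrittenToCublasCall_Sketch) {
  const char* hlo_text = R"(
HloModule F64_rewrite_check
ENTRY DotFunc {
  x = f64[2,2] parameter(0)
  y = f64[2,2] parameter(1)
  ROOT dot_a = f64[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";
  // Match the custom-call target regardless of whether cublas-lt is enabled.
  MatchOptimizedHlo(hlo_text, R"(
; CHECK: custom-call
; CHECK-SAME: custom_call_target="__cublas{{(\$lt\$matmul|\$gemm)}}"
)");
}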
16 changes: 16 additions & 0 deletions xla/tests/matmul_test.cc
@@ -51,6 +51,22 @@ class MatmulTestWithCublas : public HloTestBase,
const bool use_cublas_lt_{GetParam()};
};

TEST_P(MatmulTestWithCublas, GemmRewriter_RegressionTestF64) {
const char* module_str = R"(
HloModule GeneralMatMulActivation.7, entry_computation_layout={(f64[2,2,2]{2,1,0}, f64[2,2,2]{2,1,0})->f64[2,2,2]{2,1,0}}
ENTRY GeneralMatMulActivation.7 {
x.1 = f64[2,2,2]{2,1,0} parameter(0)
y.2 = f64[2,2,2]{2,1,0} parameter(1)
dot.3 = f64[2,2,2]{2,1,0} dot(x.1, y.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}
constant.4 = f64[] constant(0)
broadcast.5 = f64[2,2,2]{2,1,0} broadcast(constant.4), dimensions={}
ROOT maximum.6 = f64[2,2,2]{2,1,0} maximum(dot.3, broadcast.5)
})";

EXPECT_TRUE(RunAndCompare(module_str, ErrorSpec{1e-4, 1e-4}));
}

// There was an issue where the compilation process of an Inverse operation was
// resulting in a cached cuBLASLt matmul plan which was incorrectly fetched at
// the time of the Matmul operation
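For orientation, MatmulTestWithCublas above is parameterized on whether cublas-lt is used and stores the parameter in use_cublas_lt_. A minimal sketch of the usual way such a flag is routed into the compiler options follows; this override is an assumption for illustration and is not part of this diff.

// Sketch only (assumed wiring, not shown in this diff): forward the test
// parameter into the debug options so each case runs once with cublas-lt
// and once with legacy cuBLAS.
DebugOptions GetDebugOptionsForTest() override {
  DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest();
  debug_options.set_xla_gpu_enable_cublaslt(use_cublas_lt_);
  return debug_options;
}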
