From dd8fe363e560c3f2cf19eadd10278ce51f5a181b Mon Sep 17 00:00:00 2001 From: jichang Date: Thu, 31 Oct 2024 13:18:49 +0000 Subject: [PATCH] FIx test failures and add more tests for c_equal_d --- clients/gtest/matmul_gtest.yaml | 4 +++- clients/include/testing_matmul.hpp | 14 ++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/clients/gtest/matmul_gtest.yaml b/clients/gtest/matmul_gtest.yaml index db6d04697c..b5f04a57f7 100755 --- a/clients/gtest/matmul_gtest.yaml +++ b/clients/gtest/matmul_gtest.yaml @@ -1593,9 +1593,10 @@ Tests: algo_method: [0,1] transA_transB: *transA_transB_range alpha: 1 - beta: 0 + beta: [0,1] requested_solution_num: -1 unit_check: 1 + c_equal_d: [0, 1] - name: matmul_heuristic_all_solutions_real_1byte category: nightly @@ -1611,4 +1612,5 @@ Tests: requested_solution_num: -1 unit_check: 1 gpu_arch: '94[0-2]' + c_equal_d: [0, 1] ... diff --git a/clients/include/testing_matmul.hpp b/clients/include/testing_matmul.hpp index 56a7a51464..11f0912a65 100644 --- a/clients/include/testing_matmul.hpp +++ b/clients/include/testing_matmul.hpp @@ -2761,6 +2761,13 @@ void testing_matmul_with_bias(const Arguments& arg, { for(size_t sol = 0; sol < heuristicResult.size(); sol++) { + if((arg.unit_check || arg.norm_check || arg.allclose_check) && arg.c_equal_d) + { + for(int i = 0; i < gemm_count; i++) + { + CHECK_HIP_ERROR(synchronize(dC[i], hC[i], block_count)); + } + } if(!do_grouped_gemm) { if(arg.use_ext) @@ -2928,6 +2935,13 @@ void testing_matmul_with_bias(const Arguments& arg, for(size_t sol = 0; sol < heuristicResult.size(); sol++) { + if((arg.unit_check || arg.norm_check || arg.allclose_check) && arg.c_equal_d) + { + for(int i = 0; i < gemm_count; i++) + { + CHECK_HIP_ERROR(synchronize(dC[i], hC[i], block_count)); + } + } if(!do_grouped_gemm) { FrequencyMonitor& freq_monitor = getFrequencyMonitor();